Ejemplo n.º 1
0
class EpaClassifier:
    def __init__(self, config, args):
        """Prepare file names, load the reference JSON and build the helper.

        :param config: run configuration; ``tmp_fname``, ``refjson_fname``,
            ``exit_user_error`` and ``log`` are used here.
        :param args: parsed options; ``jplace_fname``, ``ignore_refalign``,
            ``output_name`` and ``output_dir`` are read.
        """
        self.cfg = config
        self.jplace_fname = args.jplace_fname
        self.ignore_refalign = args.ignore_refalign

        # Temporary working files. epa_alignment is the final alignment
        # file that is used for running EPA.
        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        self.epa_alignment = config.tmp_fname("%NAME%.afa")
        self.hmmprofile = config.tmp_fname("%NAME%.hmmprofile")
        self.tmpquery = config.tmp_fname("%NAME%.tmpquery")
        self.noalign = config.tmp_fname("%NAME%.noalign")
        self.seqs = None

        # Result files are placed in the requested output directory.
        self.out_assign_fname = os.path.join(
            args.output_dir, args.output_name + ".assignment.txt")
        self.out_jplace_fname = os.path.join(
            args.output_dir, args.output_name + ".jplace")

        try:
            self.refjson = RefJsonParser(config.refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("Invalid json file format: %s" %
                                     config.refjson_fname)

        # Validate the input json format.
        # NOTE(review): the return value of validate() is ignored here -
        # confirm whether it reports problems via its result.
        self.refjson.validate()

        self.reftree = self.refjson.get_reftree()
        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        branch_map = self.refjson.get_branch_tax_map()
        if not branch_map:
            # Old file format (before 1.6): the branch-to-taxonomy map is
            # not stored, so rebuild it from the taxonomy and the trees.
            tax_helper = TaxTreeHelper(self.cfg,
                                       self.refjson.get_origin_taxonomy())
            tax_helper.set_mf_rooted_tree(self.refjson.get_tax_tree())
            tax_helper.set_bf_unrooted_tree(self.refjson.get_reftree())
            branch_map = tax_helper.get_bid_taxonomy_map()
        self.bid_taxonomy_map = branch_map

        taxa_count = len(self.reftree.get_leaves())
        self.cfg.log.info("Loaded reference tree with %d taxa\n" % taxa_count)

        self.classify_helper = TaxClassifyHelper(self.cfg,
                                                 self.bid_taxonomy_map,
                                                 self.rate, self.node_height)

    def require_muscle(self):
        basepath = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(basepath + "/epac/bin/muscle"):
            errmsg = "The pipeline uses MUSCLE to merge alignments, please download the programm from:\n" + \
                     "http://www.drive5.com/muscle/downloads.htm\n" + \
                     "and specify path to your installation in the config file (sativa.cfg)\n"
            self.cfg.exit_user_error(errmsg)

    def require_hmmer(self):
        basepath = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(basepath +
                              "/epac/bin/hmmbuild") or not os.path.exists(
                                  basepath + "/epac/bin/hmmalign"):
            errmsg = "The pipeline uses HAMMER to align the query seqeunces, please download the programm from:\n" + \
                     "http://hmmer.janelia.org/\n" + \
                     "and specify path to your installation in the config file (sativa.cfg)\n"
            self.cfg.exit_user_error(errmsg)

    def align_to_refenence(self, noalign, minp=0.9):
        """Align the query sequences to the reference with the hmmer wrapper.

        The result of ``align()`` replaces ``self.epa_alignment``.

        :param noalign: passed to the wrapper as ``discard``.
        :param minp: passed to the wrapper as ``minp``.
        """
        ref_alignment = self.refjson.get_alignment(fout=self.tmp_refaln)
        profile = self.refjson.get_hmm_profile(self.hmmprofile)

        if not profile:
            # No hmmer profile in the json file: build it from scratch.
            profile = hmmer(self.cfg, ref_alignment).build_hmm_profile()

        aligner = hmmer(config=self.cfg,
                        refalign=ref_alignment,
                        query=self.tmpquery,
                        refprofile=profile,
                        discard=noalign,
                        seqs=self.seqs,
                        minp=minp)
        self.epa_alignment = aligner.align()

    def merge_alignment(self, query_seqs):
        refaln = self.refjson.get_alignment_list()
        with open(self.epa_alignment, "w") as fout:
            for seq in refaln:
                fout.write(">" + seq[0] + "\n" + seq[1] + "\n")
            for name, seq, comment, sid in query_seqs.iter_entries():
                fout.write(">" + name + "\n" + seq + "\n")

    def write_combined_alignment(self):
        """Write ``self.seqs`` to ``self.epa_alignment`` with fixed-up names.

        A sequence whose corrected reference id is among the reference
        sequence names keeps that reference name; every other sequence is
        treated as a query, gets the query prefix and is counted in
        ``self.query_count``.
        """
        self.query_count = 0
        # The reference names do not change while we iterate, so fetch them
        # once and use a set for constant-time membership tests instead of
        # calling the accessor (and scanning its result) for every sequence.
        ref_names = set(self.refjson.get_sequences_names())
        with open(self.epa_alignment, "w") as fout:
            for name, seq, _comment, _sid in self.seqs.iter_entries():
                ref_name = self.refjson.get_corr_seqid(
                    EpacConfig.REF_SEQ_PREFIX + name)
                if ref_name in ref_names:
                    seq_name = ref_name
                else:
                    seq_name = EpacConfig.QUERY_SEQ_PREFIX + name
                    self.query_count += 1
                fout.write(">" + seq_name + "\n" + seq + "\n")

    def checkinput(self, query_fname, minp=0.9):
        """Load the query file and prepare the alignment used by EPA.

        The input format is guessed by trying fasta and the phylip variants
        in turn. Afterwards one of four paths is taken:

        * ``ignore_refalign`` set: the input is written out as the combined
          alignment and no alignment step is performed;
        * queries are aligned and as long as the reference alignment: both
          are concatenated;
        * queries are aligned but of a different length: they are merged
          with the reference alignment using MUSCLE;
        * otherwise: queries are aligned to the reference using HMMER.

        :param query_fname: path of the query sequence file.
        :param minp: forwarded to the HMMER alignment step.
        """
        formats = (
            "fasta", "phylip", "iphylip", "phylip_relaxed", "iphylip_relaxed"
        )
        for fmt in formats:
            try:
                self.seqs = SeqGroup(sequences=query_fname, format=fmt)
                break
            except Exception:
                # A parse failure just means "not this format". Catching
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate.
                self.cfg.log.debug("Guessing input format: not " + fmt)
        if self.seqs is None:
            self.cfg.exit_user_error(
                "Invalid input file format: %s\nThe supported input formats are fasta and phylip"
                % query_fname)

        if self.ignore_refalign:
            self.cfg.log.info(
                "Assuming query file contains reference sequences, skipping the alignment step...\n"
            )
            self.write_combined_alignment()
            return

        self.query_count = len(self.seqs)

        # add query seq name prefix to avoid confusion between reference and query sequences
        self.seqs.add_name_prefix(EpacConfig.QUERY_SEQ_PREFIX)

        self.seqs.write(format="fasta", outfile=self.tmpquery)
        self.cfg.log.info("Checking if query sequences are aligned ...")
        entries = self.seqs.get_entries()
        if not entries:
            # Without this guard an empty input fails with an IndexError
            # on entries[0] below.
            self.cfg.exit_user_error(
                "Invalid input file: %s\nNo sequences found in the query file"
                % query_fname)
            return

        # Sequences count as aligned when they all have the same length.
        seql = len(entries[0][1])
        aligned = all(len(entry[1]) == seql for entry in entries[1:])

        if aligned and len(self.seqs) > 1:
            self.cfg.log.info("Query sequences are aligned")
            refalnl = self.refjson.get_alignment_length()
            if refalnl == seql:
                self.cfg.log.info(
                    "Merging query alignment with reference alignment")
                self.merge_alignment(self.seqs)
            else:
                self.cfg.log.info(
                    "Merging query alignment with reference alignment using MUSCLE"
                )
                self.require_muscle()
                refaln = self.refjson.get_alignment(fout=self.tmp_refaln)
                m = muscle(self.cfg)
                self.epa_alignment = m.merge(refaln, self.tmpquery)
        else:
            self.cfg.log.info("Query sequences are not aligned")
            self.cfg.log.info(
                "Align query sequences to the reference alignment using HMMER")
            self.require_hmmer()
            self.align_to_refenence(self.noalign, minp=minp)

    def print_ranks(self, rks, confs, minlw=0.0):
        uncorr_ranks = self.refjson.get_uncorr_ranks(rks)
        ss = ""
        css = ""
        for i in range(len(uncorr_ranks)):
            conf = confs[i]
            if conf == confs[0] and confs[0] >= 0.99:
                conf = 1.0
            if conf >= minlw:
                ss = ss + uncorr_ranks[i] + ";"
                css = css + "{0:.3f}".format(conf) + ";"
            else:
                break
        if ss == "":
            return None
        else:
            return ss[:-1] + "\t" + css[:-1]

    def run_epa(self):
        """Run RAxML-EPA on ``self.epa_alignment`` and return its result.

        The reference tree and the binary model are exported from the
        reference JSON into temporary files, the alignment is passed
        through ``reduce_alignment`` and the jplace output is handed to
        ``copy_epa_jplace`` with ``self.out_jplace_fname`` as target.

        :returns: the object returned by ``RaxmlWrapper.run_epa``.
        """
        self.cfg.log.info(
            "Running RAxML-EPA to place %d query sequences...\n" %
            self.query_count)
        # Use the configuration stored on the instance: the bare name
        # ``config`` is not defined inside this method.
        raxml = RaxmlWrapper(self.cfg)
        reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")
        self.refjson.get_raxml_readable_tree(reftree_fname)
        optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.refjson.get_binary_model(optmod_fname)
        job_name = self.cfg.subst_name("epa_%NAME%")

        reftree_str = self.refjson.get_raxml_readable_tree()
        reftree = Tree(reftree_str)

        self.reftree_size = len(reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        reduced_align_fname = raxml.reduce_alignment(self.epa_alignment)

        jp = raxml.run_epa(job_name, reduced_align_fname, reftree_fname,
                           optmod_fname)

        raxml.copy_epa_jplace(job_name, self.out_jplace_fname, move=True)

        return jp

    def run_ptp(self, jp):
        """Run EPA-PTP species delimitation and report the species clusters.

        Each cluster is logged at debug level as a comma-separated list of
        taxon names with the query prefix stripped, and written as one line
        to ``<assignment file>.species``.

        :param jp: EPA placement result, passed to ``epa_2_ptp``.
        """
        full_aln = SeqGroup(self.epa_alignment)
        species_list = epa_2_ptp(epa_jp=jp,
                                 ref_jp=self.refjson,
                                 full_alignment=full_aln,
                                 min_lw=0.5,
                                 debug=self.cfg.debug)

        self.cfg.log.debug("Species clusters:")

        # The previous code read an undefined name ``fout`` here, which
        # raised NameError. The species file is now derived from the
        # assignment output name.
        # NOTE(review): confirm this is the intended location.
        base_fname = getattr(self, "out_assign_fname", None)
        fo2 = open(base_fname + ".species", "w") if base_fname else None
        try:
            for sp_cluster in species_list:
                translated_taxa = [
                    EpacConfig.strip_query_prefix(taxon)
                    for taxon in sp_cluster
                ]
                s = ",".join(translated_taxa)
                if fo2:
                    fo2.write(s + "\n")
                self.cfg.log.debug(s)
        finally:
            # Close the file even if a cluster could not be processed.
            if fo2:
                fo2.close()

    def print_result_line(self, fo, line):
        if self.cfg.verbose:
            print(line)
        if fo:
            fo.write(line + "\n")

    def get_noalign_list(self):
        noalign_list = []
        if os.path.exists(self.noalign):
            with open(self.noalign) as fnoa:
                lines = fnoa.readlines()
                for line in lines:
                    taxon_name = line.strip()[1:]
                    origin_taxon_name = EpacConfig.strip_query_prefix(
                        taxon_name)
                    noalign_list.append(origin_taxon_name)
        return noalign_list

    def classify(self, query_fname, minp=0.9, ptp=False):
        """Assign taxonomic labels to the query sequences.

        Placements are read from ``self.jplace_fname`` when it is set;
        otherwise the input is checked/aligned and EPA is run. Each
        assigned query yields a line ``name<TAB>ranks<TAB>confidences<TAB>``
        (followed by ``*`` or ``o`` when novelty checking is enabled); each
        query without an assignment, or listed in the no-align file, yields
        ``name<TAB><TAB><TAB>?``. Lines go to ``self.out_assign_fname`` and,
        in verbose mode, to stdout.

        :param query_fname: path of the query sequence file.
        :param minp: forwarded to ``checkinput``.
        :param ptp: when true, run EPA-PTP species delimitation afterwards.
        """
        if self.jplace_fname:
            jp = EpaJsonParser(self.jplace_fname)
        else:
            self.checkinput(query_fname, minp)
            jp = self.run_epa()

        self.cfg.log.info(
            "Assigning taxonomic labels based on EPA placements...\n")

        placements = jp.get_placement()

        fo = open(self.out_assign_fname, "w") if self.out_assign_fname else None
        try:
            noassign_list = []
            for place in placements:
                taxon_name = place["n"][0]
                origin_taxon_name = EpacConfig.strip_query_prefix(taxon_name)
                edges = place["p"]

                ranks, lws = self.classify_helper.classify_seq(edges)
                rankout = self.print_ranks(ranks, lws, self.cfg.min_lhw)

                if rankout is None:
                    noassign_list.append(origin_taxon_name)
                    continue

                output = "%s\t%s\t" % (origin_taxon_name, rankout)
                if self.cfg.check_novelty:
                    isnovo = self.novelty_check(place_edge=str(edges[0][0]),
                                                ranks=ranks,
                                                lws=lws)
                    output += "*" if isnovo else "o"
                self.print_result_line(fo, output)

            noassign_list += self.get_noalign_list()

            for unassigned_name in noassign_list:
                # Report the name of *this* unassigned sequence; the old
                # code reused the variable left over from the loop above.
                self.print_result_line(fo, "%s\t\t\t?" % unassigned_name)
        finally:
            # Make sure the output is flushed and closed on errors as well.
            if fo:
                fo.close()

        # EPA-PTP species delimitation
        if ptp:
            self.run_ptp(jp)

    def novelty_check(self, place_edge, ranks, lws):
        """If the taxonomic assignment is not assigned to the genus level, 
        we need to check if it is due to the incomplete reference taxonomy or 
        it is likely to be something new:
        
        1. If the final ranks are assinged because of lw cut, that means with samller lw
        the ranks can be further assinged to lowers. This indicate the undetermined ranks 
        in the assignment is not due to the incomplete reference taxonomy, so the query 
        sequence is likely to be something new.
        
        2. Otherwise We check all leaf nodes' immediate lower rank below this ml placement point, 
        if they are not empty, output all ranks and indicate this could be novelty.
        """

        lowrank = 0
        for i in max(range(len(ranks)), 6):
            """above genus level"""
            rk = ranks[i]
            lw = lws[i]
            if rk == "-":
                break
            else:
                lowrank = lowrank + 1
                if lw >= 0 and lw < self.cfg.min_lhw:
                    return True

        if lowrank >= 5 and lowrank < len(ranks) and not ranks[lowrank] == "-":
            return False
        else:
            placenode = self.reftree.search_nodes(B=place_edge)[0]
            if placenode.is_leaf():
                return False
            else:
                leafnodes = placenode.get_leaves()
                flag = True
                for leaf in leafnodes:
                    br_num = leaf.B
                    branks = self.bid_taxonomy_map[br_num]
                    if lowrank >= len(branks) or branks[lowrank] == "-":
                        flag = False
                        break

                return flag
Ejemplo n.º 2
0
class LeaveOneTest:
    def __init__(self, config):
        """Derive output/temporary file names and reset mislabel records.

        Reports a user error if the ``.mis`` output file already exists.

        :param config: run configuration providing ``out_fname``,
            ``tmp_fname`` and ``exit_user_error``.
        """
        self.cfg = config

        # Result files.
        make_out = self.cfg.out_fname
        self.mis_fname = make_out("%NAME%.mis")
        self.premis_fname = make_out("%NAME%.premis")
        self.misrank_fname = make_out("%NAME%.misrank")
        self.stats_fname = make_out("%NAME%.stats")

        # Refuse to overwrite the results of a previous run.
        if os.path.isfile(self.mis_fname):
            print("\nERROR: Output file already exists: %s" % self.mis_fname)
            print("Please specify a different job name using -n or remove old output files.")
            self.cfg.exit_user_error()

        # Temporary working files.
        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        self.reftree_lbl_fname = config.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = config.tmp_fname("%NAME%_tax.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")

        # Mislabel records and per-level counters, for sequences and ranks.
        self.mislabels, self.mislabels_cnt = [], []
        self.rank_mislabels, self.rank_mislabels_cnt = [], []
        self.misrank_conf_map = {}
        
    def write_bid_tax_map(self, bid_tax_map, final):
        """Dump the branch-id/taxonomy map to a temporary file (debug only).

        Nothing is written unless ``cfg.debug`` is set. Each output line
        holds the branch id followed by the first three fields of its
        record, tab-separated.

        :param bid_tax_map: mapping of branch id to a record with at least
            three fields.
        :param final: selects the file name suffix ("final" vs. "l1out").
        """
        if not self.cfg.debug:
            return

        fname_suffix = "final" if final else "l1out"
        bid_fname = self.cfg.tmp_fname("%NAME%_" + "bid_tax_map_%s.txt" % fname_suffix)
        with open(bid_fname, "w") as outf:
            # items() is available on both Python 2 and 3 dictionaries,
            # unlike the Python-2-only iteritems().
            for bid, bid_rec in bid_tax_map.items():
                outf.write("%s\t%s\t%d\t%f\n" % (bid, bid_rec[0], bid_rec[1], bid_rec[2]))

    def write_assignments(self, assign_map, final):
        """Dump taxonomic assignments to a temporary file (debug only).

        Nothing is written unless ``cfg.debug`` is set. Each output line
        holds the sequence name, the ``;``-joined ranks and the ``;``-joined
        likelihood weights (three decimals), tab-separated.

        :param assign_map: mapping of sequence name to ``(ranks, lws)``.
        :param final: selects the file name suffix ("final" vs. "l1out").
        """
        if not self.cfg.debug:
            return

        fname_suffix = "final" if final else "l1out"
        assign_fname = self.cfg.tmp_fname("%NAME%_" + "taxassign_%s.txt" % fname_suffix)
        with open(assign_fname, "w") as outf:
            # items() works on Python 2 and 3 (iterkeys() is Python 2 only)
            # and avoids a second lookup per key.
            for seq_name, (ranks, lws) in assign_map.items():
                outf.write("%s\t%s\t%s\n" % (seq_name, ";".join(ranks), ";".join(["%.3f" % l for l in lws])))

    def load_refjson(self, refjson_fname):
        """Load and validate the reference JSON and initialise the helpers.

        Populates, among others, ``refjson``, ``rate``, ``node_height``,
        ``origin_taxonomy``, ``tax_tree``, ``bid_taxonomy_map``, ``reftree``,
        ``classify_helper``, ``taxtree_helper``, ``tax_code`` and
        ``taxonomy``, and resets the per-level mislabel counters.

        :param refjson_fname: path of the reference JSON file.
        """
        try:
            self.refjson = RefJsonParser(refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("ERROR: Invalid json file format!")

        # Validate the input json format.
        valid, err = self.refjson.validate()
        if not valid:
            self.cfg.log.error("ERROR: Parsing reference JSON file failed:\n%s", err)
            self.cfg.exit_user_error()

        ref = self.refjson
        self.rate = ref.get_rate()
        self.node_height = ref.get_node_height()
        self.origin_taxonomy = ref.get_origin_taxonomy()
        self.tax_tree = ref.get_tax_tree()
        self.cfg.compress_patterns = ref.get_pattern_compression()

        self.bid_taxonomy_map = ref.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # Old file format (before 1.6): rebuild this map from scratch.
            helper = TaxTreeHelper(self.cfg, self.origin_taxonomy)
            helper.set_mf_rooted_tree(self.tax_tree)
            helper.set_bf_unrooted_tree(ref.get_reftree())
            self.bid_taxonomy_map = helper.get_bid_taxonomy_map()

        self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

        self.reftree = Tree(ref.get_raxml_readable_tree())
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # When loading the pre-optimized model, the rate heterogeneity
        # mode MUST be the same as in the reference file.
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = ref.get_ratehet_model()

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map, self.rate, self.node_height)
        self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy, self.tax_tree)

        self.tax_code = TaxCode(ref.get_taxcode())

        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_map=self.origin_taxonomy)
        self.tax_common_ranks = self.taxonomy.get_common_ranks()

        # One counter per taxonomic level.
        level_count = TaxCode.UNI_TAX_LEVELS
        self.mislabels_cnt = [0] * level_count
        self.rank_mislabels_cnt = [0] * level_count
        
    def run_epa_trainer(self):
        """Run the EPA trainer and verify that the reference file exists.

        Logs an error and calls ``cfg.exit_fatal_error()`` when
        ``cfg.refjson_fname`` is not a file after the trainer finished.
        """
        epa_trainer.run_trainer(self.cfg)

        if os.path.isfile(self.cfg.refjson_fname):
            return

        self.cfg.log.error("\nBuilding reference tree failed, see error messages above.")
        self.cfg.exit_fatal_error()
        
    def classify_seq(self, placement):
        """Classify one placement record using its edge list ``placement["p"]``.

        :param placement: placement record with the edge list under "p".
        :returns: the result of ``classify_helper.classify_seq`` or, when
            the edge list is empty, ``None`` after printing an error.
        """
        edges = placement["p"]
        if len(edges) <= 0:
            # NOTE(review): callers in this class unpack the result into
            # (ranks, lws), which fails on None - confirm this cannot occur.
            print("ERROR: no placements! something is definitely wrong!")
            return None
        return self.classify_helper.classify_seq(edges)

    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        """Compare the original and the proposed ranks of one sequence.

        A mislabel record is created when ``ranks`` is empty (reported as
        "[NotIngroup]") or when, at the first level where ``ranks`` is
        non-empty and differs from ``orig_ranks``, a mismatch is found.
        The record is appended to ``self.mislabels``.

        :param seq_name: name of the sequence.
        :param orig_ranks: ranks from the original taxonomy.
        :param ranks: ranks proposed by the classification.
        :param lws: likelihood weights, indexed like ``ranks``.
        :returns: the mislabel record (dict), or ``None`` if labels agree.
        """
        mis_rec = None

        if len(ranks) == 0:
            # Nothing could be assigned at all.
            mis_rec = {
                'name': seq_name,
                'orig_level': -1,
                'real_level': 0,
                'level_name': "[NotIngroup]",
                'inv_level': 0,  # negated real_level, just for sorting
                'orig_ranks': orig_ranks,
                'ranks': [],
                'lws': [1.0],
                'conf': 1.0,
            }
        else:
            # First level that is assigned and disagrees with the original.
            mislabel_lvl = -1
            for rank_lvl in range(min(len(orig_ranks), len(ranks))):
                if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[rank_lvl] != orig_ranks[rank_lvl]:
                    mislabel_lvl = rank_lvl
                    break

            if mislabel_lvl >= 0:
                real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
                mis_rec = {
                    'name': seq_name,
                    'orig_level': mislabel_lvl,
                    'real_level': real_lvl,
                    'level_name': self.tax_code.rank_level_name(real_lvl)[0],
                    'inv_level': -1 * real_lvl,  # just for sorting
                    'orig_ranks': orig_ranks,
                    'ranks': ranks,
                    'lws': lws,
                    'conf': lws[mislabel_lvl],
                }

        if mis_rec:
            self.mislabels.append(mis_rec)

        return mis_rec
        
    def filter_mislabels(self):
        """Keep only mislabel records with ``conf`` >= ``cfg.conf_cutoff``.

        ``self.mislabels`` is replaced by a new, filtered list.
        """
        self.mislabels = [
            rec for rec in self.mislabels
            if rec['conf'] >= self.cfg.conf_cutoff
        ]

    def check_rank_tax_labels(self, rank_name, orig_ranks, ranks, lws):
        """Compare the original and the proposed ranks of a higher rank.

        The first level at which ``ranks`` is non-empty and differs from
        ``orig_ranks`` is the mislabeled level. If one exists, a record is
        built and appended to ``self.rank_mislabels``.

        :param rank_name: name of the rank under test.
        :param orig_ranks: ranks to compare against.
        :param ranks: ranks proposed by the classification.
        :param lws: likelihood weights, indexed like ``ranks``.
        :returns: the mislabel record (dict), or ``None`` if labels agree.
        """
        shared_len = min(len(orig_ranks), len(ranks))
        mislabel_lvl = next(
            (lvl for lvl in range(shared_len)
             if ranks[lvl] != Taxonomy.EMPTY_RANK and ranks[lvl] != orig_ranks[lvl]),
            -1)

        if mislabel_lvl < 0:
            return None

        real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
        mis_rec = {
            'name': rank_name,
            'orig_level': mislabel_lvl,
            'real_level': real_lvl,
            'level_name': self.tax_code.rank_level_name(real_lvl)[0],
            'inv_level': -1 * real_lvl,  # just for sorting
            'orig_ranks': orig_ranks,
            'ranks': ranks,
            'lws': lws,
            'conf': lws[mislabel_lvl],
        }
        self.rank_mislabels.append(mis_rec)
        return mis_rec

    def mis_rec_to_string_old(self, mis_rec):
        """Render a mislabel record in the old four-line text layout.

        Line 1: name, level name, original and proposed label at the
        mislabeled level, and the weight at that level (tab-separated);
        line 2: original ranks; line 3: proposed ranks (both ``;``-joined);
        line 4: all weights, tab-separated.

        :param mis_rec: mislabel record (dict).
        :returns: the formatted text, ending with a newline.
        """
        lvl = mis_rec['orig_level']
        head = mis_rec['name'] + "\t"
        summary = "%s\t%s\t%s\t%.3f\n" % (
            mis_rec['level_name'],
            mis_rec['orig_ranks'][lvl],
            mis_rec['ranks'][lvl],
            mis_rec['lws'][lvl],
        )
        orig_line = ";".join(mis_rec['orig_ranks']) + "\n"
        new_line = ";".join(mis_rec['ranks']) + "\n"
        weight_line = "\t".join(["%.3f" % conf for conf in mis_rec['lws']]) + "\n"
        return head + summary + orig_line + new_line + weight_line

    def mis_rec_to_string(self, mis_rec):
        """Render a mislabel record as one tab-separated line.

        Columns: name, level name, original label, proposed label,
        confidence, original lineage, proposed lineage, ``;``-joined
        per-rank confidences and, if present in the record, ``rank_conf``.
        Names and ranks are translated through the ``get_uncorr_*``
        accessors of the reference JSON first. For a negative
        ``orig_level`` both labels are reported as "NA".

        :param mis_rec: mislabel record (dict).
        :returns: the formatted line without trailing newline.
        """
        lvl = mis_rec['orig_level']
        uncorr_name = EpacConfig.strip_ref_prefix(self.refjson.get_uncorr_seqid(mis_rec['name']))
        orig_lineage = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
        new_lineage = self.refjson.get_uncorr_ranks(mis_rec['ranks'])

        if lvl >= 0:
            orig_label = orig_lineage[lvl]
            new_label = new_lineage[lvl]
            conf = mis_rec['lws'][lvl]
        else:
            orig_label = new_label = "NA"
            conf = mis_rec['lws'][0]

        columns = [
            uncorr_name,
            "%s\t%s\t%s\t%.3f" % (mis_rec['level_name'], orig_label, new_label, conf),
            Taxonomy.lineage_str(orig_lineage),
            Taxonomy.lineage_str(new_lineage),
            ";".join(["%.3f" % weight for weight in mis_rec['lws']]),
        ]
        if 'rank_conf' in mis_rec:
            columns.append("%.3f" % mis_rec['rank_conf'])
        return "\t".join(columns)

    def sort_mislabels(self):
        """Sort mislabel records and update the per-level counters.

        Records are ordered by ``(inv_level, conf, name)`` in descending
        order. The rank-level records are handled the same way when
        ``cfg.ranktest`` is set. Counters are incremented, not reset.
        """
        order = itemgetter('inv_level', 'conf', 'name')

        self.mislabels = sorted(self.mislabels, key=order, reverse=True)
        for rec in self.mislabels:
            self.mislabels_cnt[rec["real_level"]] += 1

        if not self.cfg.ranktest:
            return

        self.rank_mislabels = sorted(self.rank_mislabels, key=order, reverse=True)
        for rec in self.rank_mislabels:
            self.rank_mislabels_cnt[rec["real_level"]] += 1
    
    def write_stats(self, toFile=False):
        """Log the number of mislabeled sequences per rank level.

        Only levels with a positive count are reported. With
        ``cfg.ranktest`` set, a running total of the rank-level counters
        over the reported levels is appended to each line.

        :param toFile: when true, also write the lines to
            ``self.stats_fname``.
        """
        self.cfg.log.info("Mislabeled sequences by rank:")
        rank_total = 0
        report = []
        for level, seq_count in enumerate(self.mislabels_cnt):
            # Level 0 stands for sequences outside the ingroup.
            if level > 0:
                rname = self.tax_code.rank_level_name(level)[0].ljust(12)
            else:
                rname = "[NotIngroup]"

            if not seq_count > 0:
                continue

            entry = "%s:\t%d" % (rname, seq_count)
            if self.cfg.ranktest:
                rank_total += self.rank_mislabels_cnt[level]
                entry += "\t%d" % rank_total
            self.cfg.log.info(entry)
            report.append(entry)

        if not toFile:
            return
        with open(self.stats_fname, "w") as fo_stat:
            for entry in report:
                fo_stat.write(entry + "\n")
    
    def write_mislabels(self, final=True):
        """Write the mislabel tables and, for the final run, the statistics.

        Sequence-level records go to ``self.mis_fname`` (final) or
        ``self.premis_fname`` (preliminary). For the final run with
        ``cfg.ranktest`` set, rank-level records are written to
        ``self.misrank_fname``; then ``write_stats()`` is called. In
        verbose mode the final tables are echoed to stdout.

        :param final: whether this is the final (vs. preliminary) output.
        """
        common_fields = ["MislabeledLevel", "OriginalLabel", "ProposedLabel", "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath", "PerRankConfidence"]
        echo_seqs = self.cfg.verbose and final

        seq_fields = ["SeqID"] + common_fields
        if self.cfg.ranktest:
            seq_fields += ["HigherRankMisplacedConfidence"]
        seq_header = ";" + "\t".join(seq_fields) + "\n"

        out_fname = self.mis_fname if final else self.premis_fname
        with open(out_fname, "w") as fo_all:
            fo_all.write(seq_header)
            if echo_seqs and len(self.mislabels) > 0:
                print("Mislabeled sequences:\n")
                print(seq_header)
            for mis_rec in self.mislabels:
                row = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(row)
                if echo_seqs:
                    print(row)

        if not final:
            return

        if self.cfg.ranktest:
            rank_header = ";" + "\t".join(["RankID"] + common_fields) + "\n"
            with open(self.misrank_fname, "w") as fo_all:
                fo_all.write(rank_header)
                if self.cfg.verbose and len(self.rank_mislabels) > 0:
                    print("\nMislabeled higher ranks:\n")
                    print(rank_header)
                for mis_rec in self.rank_mislabels:
                    row = self.mis_rec_to_string(mis_rec) + "\n"
                    fo_all.write(row)
                    if self.cfg.verbose:
                        print(row)

        self.write_stats()
   
    def run_leave_subtree_out_test(self):
        """Run the leave-one-subtree-out test for higher taxonomic ranks.

        Inner nodes of the taxonomy tree are selected as test subtrees
        when their lowest assigned rank level is at least 2, their parent's
        is at least 1, and they hold between 2 and ``reftree_size - 4``
        leaves. The tip lists are written to a temporary file, EPA is run
        in "l1o_subtree" mode, and every placement is compared with the
        ranks of the subtree's parent via ``check_rank_tax_labels``.
        Confidences of detected mislabels are stored in
        ``self.misrank_conf_map`` keyed by taxonomic path.

        :returns: number of placements processed (0 if no subtree
            qualified).
        """
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")

        # Collect the subtrees to be tested.
        rank_tips = {}
        rank_parent = {}
        for node in self.tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            rank_lvl = Taxonomy.lowest_assigned_rank_level(ranks)
            if rank_lvl < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            parent_lvl = Taxonomy.lowest_assigned_rank_level(parent_ranks)
            if parent_lvl < 1:
                continue

            rank_seqs = node.get_leaf_names()
            rank_size = len(rank_seqs)
            if rank_size < 2 or rank_size > self.reftree_size-4:
                continue

            # Reuse the leaf names collected above instead of walking the
            # subtree a second time.
            rank_tips[tax_path] = rank_seqs
            rank_parent[tax_path] = parent_ranks

        # A real list is needed: entries are looked up by index below,
        # which a Python 3 dictionary view does not support.
        subtree_list = list(rank_tips.items())

        if len(subtree_list) == 0:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for rank_name, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname, 
            mode="l1o_subtree", subtree_fname=subtree_list_file)

        # Placements are matched to subtrees by position in subtree_list.
        subtree_count = 0
        for jp in jp_list:
            placements = jp.get_placement()
            for place in placements:
                ranks, lws = self.classify_seq(place)
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                # NOTE(review): guess_rank_level_name is not defined in the
                # visible part of this class - confirm it exists.
                rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                subtree_count += 1

        return subtree_count    
        
    def run_leave_seq_out_test(self):
        """Run the leave-one-sequence-out test.

        Placements are loaded from ``cfg.jplace_fname`` (a file mask, or a
        directory that is searched for ``*.jplace``) when it is set;
        otherwise EPA is run in "l1o_seq" mode. Each placement is
        classified and compared with the sequence's original ranks via
        ``check_seq_tax_labels``. With ``cfg.ranktest`` set, mislabel
        records also get ``rank_conf``: the highest confidence recorded in
        ``misrank_conf_map`` for the sequence's higher ranks (0 if none).

        :returns: number of placements processed.
        """
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        placements = []
        if self.cfg.jplace_fname:
            if os.path.isdir(self.cfg.jplace_fname):
                jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
            else:
                jplace_fmask = self.cfg.jplace_fname

            jplace_fname_list = glob.glob(jplace_fmask)
            for jplace_fname in jplace_fname_list:
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()

            # Log through the instance's configuration: the bare name
            # ``config`` is not defined inside this method.
            self.cfg.log.debug("Loaded %d placements from %s\n", len(placements), jplace_fmask)
        else:
            jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname, mode="l1o_seq")
            placements = jp.get_placement()
            if self.cfg.output_interim_files:
                out_jplace_fname = self.cfg.out_fname("%NAME%.l1out_seq.jplace")
                self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True, mode="l1o_seq")

        seq_count = 0
        l1out_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)
            l1out_ass[seq_name] = (ranks, lws)

            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)
            # cross-check with higher rank mislabels
            if self.cfg.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2, len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf, self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf
            seq_count += 1

        self.write_assignments(l1out_ass, final=False)

        return seq_count    
        
    def run_final_epa_test(self):
        """Re-test the suspicious sequences against trees pruned of them.

        Steps, in order:

        1. Copy the reference tree and the taxonomic tree and delete every
           leaf whose name appears in ``self.mislabels``.
        2. Reset ``self.mislabels`` to an empty list.
        3. Run EPA once on the pruned reference tree.
        4. Rebuild the branch-id -> taxonomy map from the tree returned by
           EPA (branch numbering may differ from the original tree).
        5. Classify every placement with the new map and compare it with the
           original ranks via ``check_seq_tax_labels``.
        """
        self.reftree_outgroup = self.refjson.get_outgroup()

        tmp_reftree = self.reftree.copy(method="newick")
        name2refnode = {leaf.name: leaf for leaf in tmp_reftree.iter_leaves()}

        tmp_taxtree = self.tax_tree.copy(method="newick")
        name2taxnode = {leaf.name: leaf for leaf in tmp_taxtree.iter_leaves()}

        for mis_rec in self.mislabels:
            rname = mis_rec['name']

            # NOTE: print(...) with a single pre-formatted argument behaves
            # identically under Python 2 and Python 3; the former print
            # statements were a SyntaxError under Python 3.
            if rname in name2refnode:
                name2refnode[rname].delete()
            else:
                print("Node not found in the reference tree: %s" % rname)

            if rname in name2taxnode:
                name2taxnode[rname].delete()
            else:
                print("Node not found in the taxonomic tree: %s" % rname)

        # remove unifurcation at the root
        if len(tmp_reftree.children) == 1:
            tmp_reftree = tmp_reftree.children[0]

        self.mislabels = []

        th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
        th.set_mf_rooted_tree(tmp_taxtree)

        epa_result = self.run_epa_once(tmp_reftree)

        reftree_epalbl_str = epa_result.get_std_newick_tree()
        placements = epa_result.get_placement()

        # update branchid-taxonomy mapping to account for possible changes in branch numbering
        reftree_tax = Tree(reftree_epalbl_str)
        th.set_bf_unrooted_tree(reftree_tax)
        bid_tax_map = th.get_bid_taxonomy_map()

        self.write_bid_tax_map(bid_tax_map, final=True)

        cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.rate, self.node_height)

        final_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # EXPERIMENTAL FEATURE - disabled for now!
            # It could happen that certain ranks were present in the "original" reference tree, but
            # are completely missing in the pruned tree (e.g., all seqs of a species were considered
            # "suspicious" after the leave-one-out test and thus pruned).
            # In this case, EPA has no chance to infer the full original taxonomic annotation
            # (=species) since the corresponding clade is now missing. To account for this fact, the
            # original annotation could be amended by setting ranks missing from the pruned tree to
            # "Undefined":
            #     orig_ranks = th.strip_missing_ranks(orig_ranks)

            # get EPA tax label
            ranks, lws = cl.classify_seq(place["p"])
            final_ass[seq_name] = (ranks, lws)

            # check if they match; the returned record is not needed here
            self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

        self.write_assignments(final_ass, final=True)

    def run_epa_once(self, reftree):
        """Write ``reftree`` to a temporary file and run EPA on it once.

        Args:
            reftree: tree object exposing ``write(outfile=...)``.

        Returns:
            Whatever ``self.raxml.run_epa`` returns for this job.
        """
        tree_path = self.cfg.tmp_fname("final_ref_%NAME%.tre")
        epa_job = self.cfg.subst_name("final_epa_%NAME%")

        reftree.write(outfile=tree_path)

        # IMPORTANT: an empty model file name is passed on purpose - the
        # stored model is invalid for the pruned tree and must not be loaded.
        no_model = ""
        result = self.raxml.run_epa(epa_job, self.refalign_fname, tree_path, no_model)

        if not self.cfg.output_interim_files:
            return result

        # keep the jplace file of this run among the output files
        jplace_path = self.cfg.out_fname("%NAME%.final_epa.jplace")
        self.raxml.copy_epa_jplace(epa_job, jplace_path, move=True)
        return result

    def run_test(self):
        """Run the complete mislabel-detection pipeline.

        Exports the reference tree, alignment and binary model from the
        reference JSON, optionally runs the leave-one-rank-out test, runs the
        leave-one-sequence-out test, re-checks suspicious sequences with a
        final EPA run, and finally filters, sorts and writes the mislabels.
        """
        # Use the logger of the config object owned by this instance instead
        # of a module-level ``config`` name that exists only in script mode.
        log = self.cfg.log

        self.raxml = RaxmlWrapper(self.cfg)

        self.refjson.get_raxml_readable_tree(self.reftree_fname)
        self.refalign_fname = self.refjson.get_alignment(self.tmp_refaln)
        self.refjson.get_binary_model(self.optmod_fname)

        if self.cfg.ranktest:
            log.info("Running the leave-one-rank-out test...\n")
            self.run_leave_subtree_out_test()

        log.info("Running the leave-one-sequence-out test...\n")
        self.run_leave_seq_out_test()

        if len(self.mislabels) > 0:
            log.info("Leave-one-out test identified %d suspicious sequences; running final EPA test to check them...\n", len(self.mislabels))
            if self.cfg.debug:
                self.write_mislabels(final=False)
            self.run_final_epa_test()

        self.filter_mislabels()
        self.sort_mislabels()
        self.write_mislabels()
        log.info("\nTotal mislabels: %d / %.2f %%", len(self.mislabels), (float(len(self.mislabels)) / self.reftree_size * 100))
Ejemplo n.º 3
0
class LeaveOneTest:
    def __init__(self, config):
        self.cfg = config

        self.mis_fname = self.cfg.out_fname("%NAME%.mis")
        self.premis_fname = self.cfg.out_fname("%NAME%.premis")
        self.misrank_fname = self.cfg.out_fname("%NAME%.misrank")
        self.stats_fname = self.cfg.out_fname("%NAME%.stats")

        if os.path.isfile(self.mis_fname):
            print("\nERROR: Output file already exists: %s" % self.mis_fname)
            print(
                "Please specify a different job name using -n or remove old output files."
            )
            self.cfg.exit_user_error()

        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        self.reftree_lbl_fname = config.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = config.tmp_fname("%NAME%_tax.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")

        self.mislabels = []
        self.mislabels_cnt = []
        self.rank_mislabels = []
        self.rank_mislabels_cnt = []
        self.misrank_conf_map = {}

    def write_bid_tax_map(self, bid_tax_map, final):
        if self.cfg.debug:
            fname_suffix = "final" if final else "l1out"
            bid_fname = self.cfg.tmp_fname("%NAME%_" +
                                           "bid_tax_map_%s.txt" % fname_suffix)
            with open(bid_fname, "w") as outf:
                for bid, bid_rec in bid_tax_map.items():
                    outf.write("%s\t%s\t%d\t%f\n" %
                               (bid, bid_rec[0], bid_rec[1], bid_rec[2]))

    def write_assignments(self, assign_map, final):
        if self.cfg.debug:
            fname_suffix = "final" if final else "l1out"
            assign_fname = self.cfg.tmp_fname("%NAME%_" + "taxassign_%s.txt" %
                                              fname_suffix)
            with open(assign_fname, "w") as outf:
                for seq_name in assign_map.keys():
                    ranks, lws = assign_map[seq_name]
                    outf.write("%s\t%s\t%s\n" %
                               (seq_name, ";".join(ranks), ";".join(
                                   ["%.3f" % l for l in lws])))

    def load_refjson(self, refjson_fname):
        """Load the reference JSON file and initialise everything derived from it.

        Args:
            refjson_fname: name of the reference JSON file handed to
                ``RefJsonParser``.

        On a ``ValueError`` from the parser, or when validation reports the
        file as invalid, ``self.cfg.exit_user_error`` is called.

        Side effects: sets ``refjson``, ``rate``, ``node_height``,
        ``origin_taxonomy``, ``tax_tree``, ``bid_taxonomy_map``, ``reftree``,
        ``reftree_size``, ``classify_helper``, ``taxtree_helper``,
        ``tax_code``, ``taxonomy``, ``tax_common_ranks`` and both mislabel
        counters on ``self``; updates ``compress_patterns`` (and possibly
        ``raxml_model``) on ``self.cfg``.
        """
        try:
            self.refjson = RefJsonParser(refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("ERROR: Invalid json file format!")

        # validate input json format
        (valid, err) = self.refjson.validate()
        if not valid:
            self.cfg.log.error(
                "ERROR: Parsing reference JSON file failed:\n%s", err)
            self.cfg.exit_user_error()

        # cache the values stored in the reference file
        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
        self.tax_tree = self.refjson.get_tax_tree()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        self.bid_taxonomy_map = self.refjson.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # old file format (before 1.6), need to rebuild this map from scratch
            th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
            th.set_mf_rooted_tree(self.tax_tree)
            th.set_bf_unrooted_tree(self.refjson.get_reftree())
            self.bid_taxonomy_map = th.get_bid_taxonomy_map()

        # only writes a file when self.cfg.debug is set
        self.write_bid_tax_map(self.bid_taxonomy_map, final=False)

        reftree_str = self.refjson.get_raxml_readable_tree()
        self.reftree = Tree(reftree_str)
        # tree size = number of leaves of the reference tree
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        self.classify_helper = TaxClassifyHelper(self.cfg,
                                                 self.bid_taxonomy_map,
                                                 self.rate, self.node_height)
        self.taxtree_helper = TaxTreeHelper(self.cfg, self.origin_taxonomy,
                                            self.tax_tree)

        tax_code_name = self.refjson.get_taxcode()
        self.tax_code = TaxCode(tax_code_name)

        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX,
                                 tax_map=self.origin_taxonomy)
        self.tax_common_ranks = self.taxonomy.get_common_ranks()
        #        print "Common ranks: ", self.tax_common_ranks

        # one counter slot per taxonomic level, all starting at zero
        self.mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS
        self.rank_mislabels_cnt = [0] * TaxCode.UNI_TAX_LEVELS

    def run_epa_trainer(self):
        """Run the EPA trainer and verify that it produced the reference file.

        When ``self.cfg.refjson_fname`` does not exist afterwards, an error is
        logged and ``self.cfg.exit_fatal_error()`` is called.
        """
        epa_trainer.run_trainer(self.cfg)

        # success: the trainer has written the reference JSON file
        if os.path.isfile(self.cfg.refjson_fname):
            return

        failure_msg = "\nBuilding reference tree failed, see error messages above."
        self.cfg.log.error(failure_msg)
        self.cfg.exit_fatal_error()

    def classify_seq(self, placement):
        edges = placement["p"]
        if len(edges) > 0:
            return self.classify_helper.classify_seq(edges)
        else:
            print("ERROR: no placements! something is definitely wrong!")

    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        mis_rec = None

        num_common_ranks = len(self.tax_common_ranks)
        orig_rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
        new_rank_level = Taxonomy.lowest_assigned_rank_level(ranks)
        #if new_rank_level < 0 or (new_rank_level < num_common_ranks and orig_rank_level >= num_common_ranks):
        #        if new_rank_level < 0:
        if len(ranks) == 0:
            mis_rec = {}
            mis_rec['name'] = seq_name
            mis_rec['orig_level'] = -1
            mis_rec['real_level'] = 0
            mis_rec['level_name'] = "[NotIngroup]"
            mis_rec['inv_level'] = -1 * mis_rec[
                'real_level']  # just for sorting
            mis_rec['orig_ranks'] = orig_ranks
            mis_rec['ranks'] = []
            mis_rec['lws'] = [1.0]
            mis_rec['conf'] = mis_rec['lws'][0]
        else:
            mislabel_lvl = -1
            min_len = min(len(orig_ranks), len(ranks))
            for rank_lvl in range(min_len):
                if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[
                        rank_lvl] != orig_ranks[rank_lvl]:
                    mislabel_lvl = rank_lvl
                    break

            if mislabel_lvl >= 0:
                real_lvl = self.tax_code.guess_rank_level(
                    orig_ranks, mislabel_lvl)
                mis_rec = {}
                mis_rec['name'] = seq_name
                mis_rec['orig_level'] = mislabel_lvl
                mis_rec['real_level'] = real_lvl
                mis_rec['level_name'] = self.tax_code.rank_level_name(
                    real_lvl)[0]
                mis_rec['inv_level'] = -1 * mis_rec[
                    'real_level']  # just for sorting
                mis_rec['orig_ranks'] = orig_ranks
                mis_rec['ranks'] = ranks
                mis_rec['lws'] = lws
                mis_rec['conf'] = lws[mislabel_lvl]

        if mis_rec:
            self.mislabels.append(mis_rec)

        return mis_rec

    def filter_mislabels(self):
        filtered_mis = []
        for i in range(len(self.mislabels)):
            if self.mislabels[i]['conf'] >= self.cfg.conf_cutoff:
                filtered_mis.append(self.mislabels[i])

        self.mislabels = filtered_mis

    def check_rank_tax_labels(self, rank_name, orig_ranks, ranks, lws):
        mislabel_lvl = -1
        min_len = min(len(orig_ranks), len(ranks))
        for rank_lvl in range(min_len):
            if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[
                    rank_lvl] != orig_ranks[rank_lvl]:
                mislabel_lvl = rank_lvl
                break

        if mislabel_lvl >= 0:
            real_lvl = self.tax_code.guess_rank_level(orig_ranks, mislabel_lvl)
            mis_rec = {}
            mis_rec['name'] = rank_name
            mis_rec['orig_level'] = mislabel_lvl
            mis_rec['real_level'] = real_lvl
            mis_rec['level_name'] = self.tax_code.rank_level_name(real_lvl)[0]
            mis_rec['inv_level'] = -1 * real_lvl  # just for sorting
            mis_rec['orig_ranks'] = orig_ranks
            mis_rec['ranks'] = ranks
            mis_rec['lws'] = lws
            mis_rec['conf'] = lws[mislabel_lvl]
            self.rank_mislabels.append(mis_rec)

            return mis_rec
        else:
            return None

    def mis_rec_to_string_old(self, mis_rec):
        lvl = mis_rec['orig_level']
        output = mis_rec['name'] + "\t"
        output += "%s\t%s\t%s\t%.3f\n" % (
            mis_rec['level_name'], mis_rec['orig_ranks'][lvl],
            mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
        output += ";".join(mis_rec['orig_ranks']) + "\n"
        output += ";".join(mis_rec['ranks']) + "\n"
        output += "\t".join(["%.3f" % conf for conf in mis_rec['lws']]) + "\n"
        return output

    def mis_rec_to_string(self, mis_rec):
        """Format a mislabel record as one tab-separated line.

        The sequence name and both rank lists are first translated with the
        ``get_uncorr_*`` methods of the reference JSON object. For a negative
        ``orig_level`` the original and proposed labels are written as
        ``NA`` and the first confidence value is used.

        Args:
            mis_rec: mislabel record (dict); an optional ``rank_conf`` entry
                is appended as the last column.

        Returns:
            The formatted line without a trailing newline.
        """
        level = mis_rec['orig_level']
        seq_id = EpacConfig.strip_ref_prefix(
            self.refjson.get_uncorr_seqid(mis_rec['name']))
        old_ranks = self.refjson.get_uncorr_ranks(mis_rec['orig_ranks'])
        new_ranks = self.refjson.get_uncorr_ranks(mis_rec['ranks'])

        if level >= 0:
            labels = (old_ranks[level], new_ranks[level],
                      mis_rec['lws'][level])
        else:
            labels = ("NA", "NA", mis_rec['lws'][0])

        chunks = [
            seq_id + "\t",
            "%s\t%s\t%s\t%.3f\t" % ((mis_rec['level_name'], ) + labels),
            Taxonomy.lineage_str(old_ranks) + "\t",
            Taxonomy.lineage_str(new_ranks) + "\t",
            ";".join(["%.3f" % conf for conf in mis_rec['lws']]),
        ]
        if 'rank_conf' in mis_rec:
            chunks.append("\t%.3f" % mis_rec['rank_conf'])
        return "".join(chunks)

    def sort_mislabels(self):
        self.mislabels = sorted(self.mislabels,
                                key=itemgetter('inv_level', 'conf', 'name'),
                                reverse=True)
        for mis_rec in self.mislabels:
            real_lvl = mis_rec["real_level"]
            self.mislabels_cnt[real_lvl] += 1

        if self.cfg.ranktest:
            self.rank_mislabels = sorted(self.rank_mislabels,
                                         key=itemgetter(
                                             'inv_level', 'conf', 'name'),
                                         reverse=True)
            for mis_rec in self.rank_mislabels:
                real_lvl = mis_rec["real_level"]
                self.rank_mislabels_cnt[real_lvl] += 1

    def write_stats(self, toFile=False):
        self.cfg.log.info("Mislabeled sequences by rank:")
        seq_sum = 0
        rank_sum = 0
        stats = []
        for i in range(len(self.mislabels_cnt)):
            if i > 0:
                rname = self.tax_code.rank_level_name(i)[0].ljust(12)
            else:
                rname = "[NotIngroup]"
            if self.mislabels_cnt[i] > 0:
                seq_sum += self.mislabels_cnt[i]
                #                    output = "%s:\t%d" % (rname, seq_sum)
                output = "%s:\t%d" % (rname, self.mislabels_cnt[i])
                if self.cfg.ranktest:
                    rank_sum += self.rank_mislabels_cnt[i]
                    output += "\t%d" % rank_sum
                self.cfg.log.info(output)
                stats.append(output)

        if toFile:
            with open(self.stats_fname, "w") as fo_stat:
                for line in stats:
                    fo_stat.write(line + "\n")

    def write_mislabels_header(self, fo, final, fields):
        header = ";" + "\t".join(fields) + "\n"

        # write to file
        if final:
            for line in DISCLAIMER.split("\n"):
                fo.write(";%s\n" % line)
            fo.write(";\n")
        fo.write(header)

        # print to console
        if final and self.cfg.verbose and len(self.rank_mislabels) > 0:
            print(DISCLAIMER, "\n")
            print("Mislabeled sequences:\n")
            print(header)

    def write_rank_mislabels(self, final=True):
        if not self.cfg.ranktest:
            return

        with open(self.misrank_fname, "w") as fo_all:
            fields = [
                "RankID", "MislabeledLevel", "OriginalLabel", "ProposedLabel",
                "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath",
                "PerRankConfidence"
            ]
            self.write_mislabels_header(fo_all, final, fields)
            for mis_rec in self.rank_mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose:
                    print(output)

    def write_mislabels(self, final=True):
        if final:
            out_fname = self.mis_fname
        else:
            out_fname = self.premis_fname

        with open(out_fname, "w") as fo_all:
            fields = [
                "SeqID", "MislabeledLevel", "OriginalLabel", "ProposedLabel",
                "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath",
                "PerRankConfidence"
            ]
            if self.cfg.ranktest:
                fields += ["HigherRankMisplacedConfidence"]
            self.write_mislabels_header(fo_all, final, fields)
            for mis_rec in self.mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose and final:
                    print(output)

        if final:
            self.write_rank_mislabels()
            self.write_stats()

    def get_parent_tip_ranks(self, tax_tree):
        rank_tips = {}
        rank_parent = {}
        for node in tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            rank_lvl = Taxonomy.lowest_assigned_rank_level(ranks)
            if rank_lvl < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            parent_lvl = Taxonomy.lowest_assigned_rank_level(parent_ranks)
            if parent_lvl < 1:
                continue

            rank_seqs = node.get_leaf_names()
            rank_size = len(rank_seqs)
            if rank_size < 2 or rank_size > self.reftree_size - 4:
                continue


#            print rank_lvl, "\t", tax_path, "\t", rank_seqs, "\n"
            rank_tips[tax_path] = node.get_leaf_names()
            rank_parent[tax_path] = parent_ranks

        return rank_parent, rank_tips

    def run_leave_subtree_out_test(self):
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")
        #        if self.jplace_fname:
        #            jp = EpaJsonParser(self.jplace_fname)
        #        else:

        #create file with subtrees
        rank_parent, rank_tips = self.get_parent_tip_ranks(self.tax_tree)

        subtree_list = list(rank_tips.items())
        if len(subtree_list) == 0:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for rank_name, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name,
                                     self.refalign_fname,
                                     self.reftree_fname,
                                     self.optmod_fname,
                                     mode="l1o_subtree",
                                     subtree_fname=subtree_list_file)

        subtree_count = 0
        for jp in jp_list:
            placements = jp.get_placement()
            for place in placements:
                ranks, lws = self.classify_seq(place)
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                rank_prefix = self.tax_code.guess_rank_level_name(
                    orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                #                print orig_ranks, "\n", parent_ranks, "\n", ranks, "\n"
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks,
                                                     ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                subtree_count += 1

        return subtree_count

    def run_leave_seq_out_test(self):
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        placements = []
        if self.cfg.jplace_fname:
            if os.path.isdir(self.cfg.jplace_fname):
                jplace_fmask = os.path.join(self.cfg.jplace_fname, '*.jplace')
            else:
                jplace_fmask = self.cfg.jplace_fname

            jplace_fname_list = glob.glob(jplace_fmask)
            for jplace_fname in jplace_fname_list:
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()

            config.log.debug("Loaded %d placements from %s\n", len(placements),
                             jplace_fmask)
        else:
            jp = self.raxml.run_epa(job_name,
                                    self.refalign_fname,
                                    self.reftree_fname,
                                    self.optmod_fname,
                                    mode="l1o_seq")
            placements = jp.get_placement()
            if self.cfg.output_interim_files:
                out_jplace_fname = self.cfg.out_fname(
                    "%NAME%.l1out_seq.jplace")
                self.raxml.copy_epa_jplace(job_name,
                                           out_jplace_fname,
                                           move=True,
                                           mode="l1o_seq")

        seq_count = 0
        l1out_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            #            orig_ranks = self.get_orig_ranks(seq_name)
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)
            l1out_ass[seq_name] = (ranks, lws)

            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks,
                                                lws)
            # cross-check with higher rank mislabels
            if self.cfg.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2, len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf,
                                        self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf
            seq_count += 1

        self.write_assignments(l1out_ass, final=False)

        return seq_count

    def prune_mislabels_from_tree(self, src_tree, tree_name):
        pruned_tree = src_tree.copy(method="newick")
        name2node = {}
        for leaf in pruned_tree.iter_leaves():
            name2node[leaf.name] = leaf

        for mis_rec in self.mislabels:
            rname = mis_rec['name']
            #            rname = EpacConfig.REF_SEQ_PREFIX + name

            if rname in name2node:
                name2node[rname].delete()
            else:
                config.log.debug("Node not found in the %s tree: %s" %
                                 (tree_name, rname))

        return pruned_tree

    def run_final_epa_test(self):
        """Re-test the suspicious sequences against trees pruned of them.

        Steps, in order:

        1. Build pruned copies of the reference tree and of the taxonomic
           tree, without the leaves named in ``self.mislabels``.
        2. Reset ``self.mislabels`` to an empty list.
        3. Obtain placements either from existing jplace file(s) named by
           ``self.cfg.final_jplace_fname`` (a directory is expanded to
           ``<dir>/*.jplace``) or from one EPA run on the pruned reference
           tree.
        4. Rebuild the branch-id -> taxonomy map from the EPA-labelled tree
           (branch numbering may differ from the original tree).
        5. Classify every placement with the new map and compare it with the
           original ranks via ``check_seq_tax_labels``, which re-populates
           ``self.mislabels``.
        """
        self.reftree_outgroup = self.refjson.get_outgroup()

        pruned_reftree = self.prune_mislabels_from_tree(
            self.reftree, "reference")
        # The taxonomic tree (not the reference tree) has to be pruned here:
        # the result is used as the multifurcating rooted tree of the
        # taxonomy helper below.
        pruned_taxtree = self.prune_mislabels_from_tree(
            self.tax_tree, "taxonomic")

        # remove unifurcation at the root
        if len(pruned_reftree.children) == 1:
            pruned_reftree = pruned_reftree.children[0]

        self.mislabels = []

        th = TaxTreeHelper(self.cfg, self.origin_taxonomy)
        th.set_mf_rooted_tree(pruned_taxtree)

        reftree_epalbl_str = None
        if self.cfg.final_jplace_fname:
            if os.path.isdir(self.cfg.final_jplace_fname):
                jplace_fmask = os.path.join(self.cfg.final_jplace_fname,
                                            '*.jplace')
            else:
                jplace_fmask = self.cfg.final_jplace_fname

            placements = []
            for jplace_fname in glob.glob(jplace_fmask):
                jp = EpaJsonParser(jplace_fname)
                placements += jp.get_placement()
                # the labelled tree is taken from the first file providing one
                if not reftree_epalbl_str:
                    reftree_epalbl_str = jp.get_std_newick_tree()

            # Log through the config object owned by this instance; the
            # module-level ``config`` name exists only in script mode.
            self.cfg.log.debug("Loaded %d final epa placements from %s\n",
                               len(placements), jplace_fmask)
        else:
            epa_result = self.run_epa_once(pruned_reftree)
            reftree_epalbl_str = epa_result.get_std_newick_tree()
            placements = epa_result.get_placement()

        # update branchid-taxonomy mapping to account for possible changes in branch numbering
        reftree_tax = Tree(reftree_epalbl_str)
        th.set_bf_unrooted_tree(reftree_tax)
        bid_tax_map = th.get_bid_taxonomy_map()

        self.write_bid_tax_map(bid_tax_map, final=True)

        cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.rate,
                               self.node_height)

        final_ass = {}
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.taxtree_helper.get_seq_ranks_from_tree(seq_name)

            # EXPERIMENTAL FEATURE - disabled for now!
            # It could happen that certain ranks were present in the "original" reference tree, but
            # are completely missing in the pruned tree (e.g., all seqs of a species were considered
            # "suspicious" after the leave-one-out test and thus pruned).
            # In this case, EPA has no chance to infer the full original taxonomic annotation
            # (=species) since the corresponding clade is now missing. To account for this fact, the
            # original annotation could be amended by setting ranks missing from the pruned tree to
            # "Undefined":
            #     orig_ranks = th.strip_missing_ranks(orig_ranks)

            # get EPA tax label
            ranks, lws = cl.classify_seq(place["p"])
            final_ass[seq_name] = (ranks, lws)

            # check if they match; the returned record is not needed here
            self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

        self.write_assignments(final_ass, final=True)

    def run_epa_once(self, reftree):
        reftree_fname = self.cfg.tmp_fname("final_ref_%NAME%.tre")
        job_name = self.cfg.subst_name("final_epa_%NAME%")

        reftree.write(outfile=reftree_fname)

        # IMPORTANT: don't load the model, since it's invalid for the pruned true !!!
        optmod_fname = ""
        epa_result = self.raxml.run_epa(job_name, self.refalign_fname,
                                        reftree_fname, optmod_fname)

        if self.cfg.output_interim_files:
            out_jplace_fname = self.cfg.out_fname("%NAME%.final_epa.jplace")
            self.raxml.copy_epa_jplace(job_name, out_jplace_fname, move=True)

        return epa_result

    def run_test(self):
        """Run the complete mislabel-detection pipeline.

        Exports the reference tree, alignment and binary model from the
        reference JSON, optionally runs the leave-one-rank-out test, runs the
        leave-one-sequence-out test, re-checks suspicious sequences with a
        final EPA run, and finally filters, sorts and writes the mislabels.
        """
        # Use the logger of the config object owned by this instance instead
        # of a module-level ``config`` name that exists only in script mode.
        log = self.cfg.log

        self.raxml = RaxmlWrapper(self.cfg)

        self.refjson.get_raxml_readable_tree(self.reftree_fname)
        self.refalign_fname = self.refjson.get_alignment(self.tmp_refaln)
        self.refjson.get_binary_model(self.optmod_fname)

        if self.cfg.ranktest:
            log.info("Running the leave-one-rank-out test...\n")
            self.run_leave_subtree_out_test()

        log.info("Running the leave-one-sequence-out test...\n")
        self.run_leave_seq_out_test()

        if len(self.mislabels) > 0:
            log.info(
                "Leave-one-out test identified %d suspicious sequences; running final EPA test to check them...\n",
                len(self.mislabels))
            # keep the preliminary list for inspection in debug mode
            if self.cfg.debug:
                self.write_mislabels(final=False)
            self.run_final_epa_test()

        self.filter_mislabels()
        self.sort_mislabels()
        self.write_mislabels()
        log.info("\nTotal mislabels: %d / %.2f %%", len(self.mislabels),
                 (float(len(self.mislabels)) / self.reftree_size * 100))
Ejemplo n.º 4
0
class LeaveOneTest:
    """Leave-one-out test for spotting mislabeled entries in a reference set.

    Reference sequences (and, when `ranktest` is enabled, whole higher-rank
    subtrees) are placed with EPA in leave-one-out mode; the taxonomic label
    derived from each placement is compared with the original label and every
    disagreement is recorded as a "mislabel" record (a dict, see
    `_make_mis_rec`).
    """

    # Universal rank level -> (label prefix, rank name).
    _RANK_LEVEL_NAMES = {
        0: ("?__", "Unknown"),
        1: ("k__", "Kingdom"),
        2: ("p__", "Phylum"),
        3: ("c__", "Class"),
        4: ("d__", "Subclass"),
        5: ("o__", "Order"),
        6: ("n__", "Suborder"),
        7: ("f__", "Family"),
        8: ("g__", "Genus"),
        9: ("s__", "Species"),
    }

    def __init__(self, config, args):
        """Load the reference JSON and initialise all state used by the test.

        config -- run configuration; provides `tmp_fname`, `refjson_fname`
                  and the settings read/updated below.
        args   -- parsed command-line arguments; `method`, `min_lhw`,
                  `jplace_fname`, `ranktest`, `output_dir` and `output_name`
                  are read.

        Terminates the process (non-zero status) if the reference JSON cannot
        be parsed.
        """
        self.cfg = config
        self.method = args.method
        self.minlw = args.min_lhw
        self.jplace_fname = args.jplace_fname
        self.ranktest = args.ranktest
        self.output_fname = args.output_dir + "/" + args.output_name

        # switch off branch length filter
        self.brlen_pv = 0.

        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        self.reftree_lbl_fname = config.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = config.tmp_fname("%NAME%_tax.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")

        try:
            self.refjson = RefJsonParser(config.refjson_fname, ver="1.2")
        except ValueError:
            print("Invalid json file format!")
            # fatal error: report failure through the exit status
            sys.exit(1)
        # validate input json format
        self.refjson.validate()
        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.origin_taxonomy = self.refjson.get_origin_taxonomy()
        self.bid_taxonomy_map = self.refjson.get_bid_tanomomy_map()
        self.tax_tree = self.refjson.get_tax_tree()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        reftree_str = self.refjson.get_raxml_readable_tree()
        self.reftree = Tree(reftree_str)
        self.reftree_size = len(self.reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map, self.brlen_pv, self.rate, self.node_height)

        self.TAXONOMY_RANKS_COUNT = 10
        self.mislabels = []
        self.mislabels_cnt = [0] * self.TAXONOMY_RANKS_COUNT
        self.rank_mislabels = []
        self.rank_mislabels_cnt = [0] * self.TAXONOMY_RANKS_COUNT
        # taxonomic path of a misplaced rank -> confidence of that misplacement
        self.misrank_conf_map = {}

    def cleanup(self):
        """Remove the temporary reference alignment file, if present."""
        FileUtils.remove_if_exists(self.tmp_refaln)

    def classify_seq(self, placement):
        """Derive a taxonomic label from one EPA placement record.

        Returns whatever `classify_helper.classify_seq` returns for the
        placement's edge list (callers unpack it as ``ranks, lws``). If the
        placement has no edges, an error is printed and None is returned.
        """
        edges = placement["p"]
        if len(edges) > 0:
            return self.classify_helper.classify_seq(edges, self.method, self.minlw)
        print("ERROR: no placements! something is definitely wrong!")
        return None

    def rank_level_name(self, uni_rank_level):
        """Return the ``(prefix, name)`` pair for a universal rank level (0-9).

        Raises KeyError for a level outside that range.
        """
        return self._RANK_LEVEL_NAMES[uni_rank_level]

    def guess_rank_level(self, ranks, rank_level):
        """Guess the universal rank level (1-9) of ``ranks[rank_level]``.

        Well-known prefixes ("k__", "p__", ...) and suffixes ("ales", "ceae",
        ...) are tried first. If none matches, the level is taken to be one
        below the (recursively guessed) level of the parent rank; for
        lineages with fewer than 8 ranks the sub-rank levels 4 and 6 are
        skipped in that case.
        """
        rank_name = ranks[rank_level]

        real_level = 0

        # check common prefixes and suffixes
        if rank_name.startswith("k__") or rank_name.lower() in ["bacteria", "archaea", "eukaryota"]:
            real_level = 1
        elif rank_name.startswith("p__"):
            real_level = 2
        elif rank_name.startswith("c__"):
            real_level = 3
        elif rank_name.endswith("dae"):
            real_level = 4
        elif rank_name.startswith("o__") or rank_name.endswith("ales"):
            real_level = 5
        elif rank_name.endswith("neae"):
            real_level = 6
        elif rank_name.startswith("f__") or rank_name.endswith("ceae"):
            real_level = 7
        elif rank_name.startswith("g__"):
            real_level = 8
        elif rank_name.startswith("s__"):
            real_level = 9

        if real_level == 0:
            if rank_level == 0:    # kingdom
                real_level = 1
            else:
                parent_level = self.guess_rank_level(ranks, rank_level-1)
                real_level = parent_level + 1
                if len(ranks) < 8 and (real_level in [4,6]):
                    real_level += 1

        return real_level

    def guess_rank_level_name(self, ranks, rank_level):
        """Return the ``(prefix, name)`` pair of the guessed level of ``ranks[rank_level]``."""
        real_level = self.guess_rank_level(ranks, rank_level)
        return self.rank_level_name(real_level)

    def _make_mis_rec(self, orig_ranks, ranks, lws):
        """Build a mislabel record for the first conflicting rank, or return None.

        A rank conflicts when the proposed rank is not `Taxonomy.EMPTY_RANK`
        and differs from the original one at the same position. The record's
        'name' entry is left as None for the caller to fill in.
        """
        mislabel_lvl = -1
        for rank_lvl in range(min(len(orig_ranks), len(ranks))):
            if ranks[rank_lvl] != Taxonomy.EMPTY_RANK and ranks[rank_lvl] != orig_ranks[rank_lvl]:
                mislabel_lvl = rank_lvl
                break

        if mislabel_lvl < 0:
            return None

        real_lvl = self.guess_rank_level(orig_ranks, mislabel_lvl)
        return {
            'name': None,
            'orig_level': mislabel_lvl,
            'real_level': real_lvl,
            'level_name': self.rank_level_name(real_lvl)[1],
            'inv_level': -1 * real_lvl,  # just for sorting
            'orig_ranks': orig_ranks,
            'ranks': ranks,
            'lws': lws,
            'conf': lws[mislabel_lvl],
        }

    def check_seq_tax_labels(self, seq_name, orig_ranks, ranks, lws):
        """Compare a sequence's original and proposed ranks.

        On a conflict the mislabel record is appended to `self.mislabels`
        (named after the sequence with its reference prefix stripped) and
        returned; otherwise None is returned.
        """
        mis_rec = self._make_mis_rec(orig_ranks, ranks, lws)
        if mis_rec is not None:
            mis_rec['name'] = EpacConfig.strip_ref_prefix(seq_name)
            self.mislabels.append(mis_rec)
        return mis_rec

    def check_rank_tax_labels(self, rank_name, orig_ranks, ranks, lws):
        """Compare a higher rank's original and proposed ranks.

        On a conflict the mislabel record is appended to
        `self.rank_mislabels` and returned; otherwise None is returned.
        """
        mis_rec = self._make_mis_rec(orig_ranks, ranks, lws)
        if mis_rec is not None:
            mis_rec['name'] = rank_name
            self.rank_mislabels.append(mis_rec)
        return mis_rec

    def mis_rec_to_string_old(self, mis_rec):
        """Format a mislabel record in the old multi-line layout."""
        lvl = mis_rec['orig_level']
        output = mis_rec['name'] + "\t"
        output += "%s\t%s\t%s\t%.3f\n" % (mis_rec['level_name'],
            mis_rec['orig_ranks'][lvl], mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
        output += ";".join(mis_rec['orig_ranks']) + "\n"
        output += ";".join(mis_rec['ranks']) + "\n"
        output += "\t".join(["%.3f" % conf for conf in mis_rec['lws']]) + "\n"
        return output

    def mis_rec_to_string(self, mis_rec):
        """Format a mislabel record as one tab-separated line (no trailing newline).

        The optional 'rank_conf' entry is appended as a last column when present.
        """
        lvl = mis_rec['orig_level']
        output = mis_rec['name'] + "\t"
        output += "%s\t%s\t%s\t%.3f\t" % (mis_rec['level_name'],
            mis_rec['orig_ranks'][lvl], mis_rec['ranks'][lvl], mis_rec['lws'][lvl])
        output += Taxonomy.lineage_str(mis_rec['orig_ranks']) + "\t"
        output += Taxonomy.lineage_str(mis_rec['ranks']) + "\t"
        output += ";".join(["%.3f" % conf for conf in mis_rec['lws']])
        if 'rank_conf' in mis_rec:
            output += "\t%.3f" % mis_rec['rank_conf']
        return output

    def sort_mislabels(self):
        """Sort the mislabel lists and recompute the per-rank counters.

        Records are ordered by ('inv_level', 'conf') in descending order. The
        counters are rebuilt from scratch, so calling this method more than
        once does not inflate them.
        """
        sort_key = itemgetter('inv_level', 'conf')

        self.mislabels = sorted(self.mislabels, key=sort_key, reverse=True)
        self.mislabels_cnt = [0] * self.TAXONOMY_RANKS_COUNT
        for mis_rec in self.mislabels:
            self.mislabels_cnt[mis_rec["real_level"]] += 1

        if self.ranktest:
            self.rank_mislabels = sorted(self.rank_mislabels, key=sort_key, reverse=True)
            self.rank_mislabels_cnt = [0] * self.TAXONOMY_RANKS_COUNT
            for mis_rec in self.rank_mislabels:
                self.rank_mislabels_cnt[mis_rec["real_level"]] += 1

    def write_mislabels(self, final=True):
        """Write the mislabel report files.

        final=False writes only the preliminary "<output>.premis" list.
        final=True writes "<output>.mis", plus "<output>.misrank" when
        `ranktest` is on, and "<output>.stats" with running (cumulative)
        totals per rank level, which are also printed.
        """
        if final:
            out_fname = "%s.mis" % self.output_fname
        else:
            out_fname = "%s.premis" % self.output_fname

        with open(out_fname, "w") as fo_all:
            fields = ["SeqID", "MislabeledLevel", "OriginalLabel", "ProposedLabel", "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath", "PerRankConfidence"]
            if self.ranktest:
                fields += ["HigherRankMisplacedConfidence"]
            header = ";" + "\t".join(fields) + "\n"
            fo_all.write(header)
            if self.cfg.verbose and len(self.mislabels) > 0 and final:
                print("Mislabeled sequences:\n")
                print(header)
            for mis_rec in self.mislabels:
                output = self.mis_rec_to_string(mis_rec) + "\n"
                fo_all.write(output)
                if self.cfg.verbose and final:
                    print(output)

        if not final:
            return

        if self.ranktest:
            with open("%s.misrank" % self.output_fname, "w") as fo_all:
                fields = ["RankID", "MislabeledLevel", "OriginalLabel", "ProposedLabel", "Confidence", "OriginalTaxonomyPath", "ProposedTaxonomyPath", "PerRankConfidence"]
                header = ";" + "\t".join(fields) + "\n"
                fo_all.write(header)
                if self.cfg.verbose and len(self.rank_mislabels) > 0:
                    print("\nMislabeled higher ranks:\n")
                    print(header)
                for mis_rec in self.rank_mislabels:
                    output = self.mis_rec_to_string(mis_rec) + "\n"
                    fo_all.write(output)
                    if self.cfg.verbose:
                        print(output)

        print("Mislabels counts by ranks:")
        with open("%s.stats" % self.output_fname, "w") as fo_stat:
            seq_sum = 0
            rank_sum = 0
            for i in range(1, self.TAXONOMY_RANKS_COUNT):
                rname = self.rank_level_name(i)[1].ljust(10)
                # sub-rank levels (4, 6) are listed only if they have mislabels
                if self.mislabels_cnt[i] > 0 or i not in [4,6]:
                    seq_sum += self.mislabels_cnt[i]
                    output = "%s:\t%d" % (rname, seq_sum)
                    if self.ranktest:
                        rank_sum += self.rank_mislabels_cnt[i]
                        output += "\t%d" % rank_sum
                    fo_stat.write(output + "\n")
                    print(output)

    def get_orig_ranks(self, seq_name):
        """Return the original ranks of a sequence, read from its parent node
        in the taxonomic tree.

        Terminates the process (non-zero status) unless exactly one leaf with
        that name exists.
        """
        nodes = self.tax_tree.get_leaves_by_name(seq_name)
        if len(nodes) != 1:
            print("FATAL ERROR: Sequence %s is not found in the taxonomic tree, or is present more than once!" % seq_name)
            sys.exit(1)
        seq_node = nodes[0]
        orig_ranks = Taxonomy.split_rank_uid(seq_node.up.name)
        return orig_ranks

    def run_leave_subtree_out_test(self):
        """Run the leave-one-rank-out test and record misplaced higher ranks.

        Internal taxonomic-tree nodes that pass the level/size filters below
        are written, one tip list per line, to a subtree list file which is
        passed to EPA (mode "l1o_subtree"). Each resulting placement is
        matched to its subtree by position and checked against the ranks of
        the subtree's parent. Returns the number of placements processed
        (0 if no subtree qualified).
        """
        job_name = self.cfg.subst_name("l1out_rank_%NAME%")

        # collect the subtrees to test
        rank_tips = {}
        rank_parent = {}
        for node in self.tax_tree.traverse("postorder"):
            if node.is_leaf() or node.is_root():
                continue
            tax_path = node.name
            ranks = Taxonomy.split_rank_uid(tax_path)
            rank_lvl = Taxonomy.lowest_assigned_rank_level(ranks)
            if rank_lvl < 2:
                continue

            parent_ranks = Taxonomy.split_rank_uid(node.up.name)
            parent_lvl = Taxonomy.lowest_assigned_rank_level(parent_ranks)
            if parent_lvl < 1:
                continue

            rank_seqs = node.get_leaf_names()
            rank_size = len(rank_seqs)
            if rank_size < 2 or rank_size > self.reftree_size-4:
                continue

            rank_tips[tax_path] = rank_seqs
            rank_parent[tax_path] = parent_ranks

        # Materialise the items: the list is indexed by position below, which
        # a dict view does not support.
        subtree_list = list(rank_tips.items())

        if not subtree_list:
            return 0

        subtree_list_file = self.cfg.tmp_fname("treelist_%NAME%.txt")
        with open(subtree_list_file, "w") as fout:
            for rank_name, tips in subtree_list:
                fout.write("%s\n" % " ".join(tips))

        jp_list = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname,
            mode="l1o_subtree", subtree_fname=subtree_list_file)

        subtree_count = 0
        for jp in jp_list:
            placements = jp.get_placement()
            for place in placements:
                ranks, lws = self.classify_seq(place)
                # placements come back in the order the subtrees were written
                tax_path = subtree_list[subtree_count][0]
                orig_ranks = Taxonomy.split_rank_uid(tax_path)
                rank_level = Taxonomy.lowest_assigned_rank_level(orig_ranks)
                rank_prefix = self.guess_rank_level_name(orig_ranks, rank_level)[0]
                rank_name = orig_ranks[rank_level]
                if not rank_name.startswith(rank_prefix):
                    rank_name = rank_prefix + rank_name
                parent_ranks = rank_parent[tax_path]
                mis_rec = self.check_rank_tax_labels(rank_name, parent_ranks, ranks, lws)
                if mis_rec:
                    self.misrank_conf_map[tax_path] = mis_rec['conf']
                subtree_count += 1

        return subtree_count

    def run_leave_seq_out_test(self):
        """Run the leave-one-sequence-out test and record mislabeled sequences.

        Placements are read from `self.jplace_fname` when it is set, otherwise
        produced by an EPA run in mode "l1o_seq". When `ranktest` is on, each
        mislabel record also gets a 'rank_conf' entry: the highest
        misplacement confidence recorded for any of the sequence's higher
        ranks (0 if none). Returns the number of placements processed.
        """
        job_name = self.cfg.subst_name("l1out_seq_%NAME%")
        if self.jplace_fname:
            jp = EpaJsonParser(self.jplace_fname)
        else:
            jp = self.raxml.run_epa(job_name, self.refalign_fname, self.reftree_fname, self.optmod_fname, mode="l1o_seq")

        placements = jp.get_placement()
        seq_count = 0
        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.get_orig_ranks(seq_name)

            # get EPA tax label
            ranks, lws = self.classify_seq(place)
            # check if they match
            mis_rec = self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)
            # cross-check with higher rank mislabels
            if self.ranktest and mis_rec:
                rank_conf = 0
                for lvl in range(2,len(orig_ranks)):
                    tax_path = Taxonomy.get_rank_uid(orig_ranks, lvl)
                    if tax_path in self.misrank_conf_map:
                        rank_conf = max(rank_conf, self.misrank_conf_map[tax_path])
                mis_rec['rank_conf'] = rank_conf
            seq_count += 1

        return seq_count

    def run_final_epa_test(self):
        """Re-check the suspicious sequences against a pruned reference.

        All currently recorded mislabels are deleted from copies of the
        reference and taxonomic trees, `self.mislabels` is reset, and a single
        EPA run on the pruned tree (see `run_epa_once`) repopulates it.
        """
        self.reftree_outgroup = self.refjson.get_outgroup()
        tmp_reftree = self.reftree.copy()
        tmp_taxtree = self.tax_tree.copy()
        for mis_rec in self.mislabels:
            name = mis_rec['name']
            rname = EpacConfig.REF_SEQ_PREFIX + name

            leaf_nodes = tmp_reftree.get_leaves_by_name(rname)
            if len(leaf_nodes) > 0:
                leaf_nodes[0].delete()
            else:
                print("Node not found in the reference tree: %s" % rname)

            leaf_nodes = tmp_taxtree.get_leaves_by_name(rname)
            if len(leaf_nodes) > 0:
                leaf_nodes[0].delete()
            else:
                print("Node not found in the taxonomic tree: %s" % rname)

        # remove unifurcation at the root
        if len(tmp_reftree.children) == 1:
            tmp_reftree = tmp_reftree.children[0]

        self.mislabels = []

        th = TaxTreeHelper(self.origin_taxonomy, self.cfg)
        th.set_mf_rooted_tree(tmp_taxtree)

        self.run_epa_once(tmp_reftree, th)

    def run_epa_once(self, reftree, th):
        """Run EPA once on `reftree` and record the mislabels it reveals.

        reftree -- reference tree to place onto; written to a temporary file.
        th      -- tax-tree helper used to rebuild the branch-id -> taxonomy
                   map for the tree labelled by this EPA run.

        Unless `cfg.debug` is set, the job's files and the temporary tree file
        are removed afterwards.
        """
        reftree_fname = self.cfg.tmp_fname("final_ref_%NAME%.tre")
        job_name = self.cfg.subst_name("final_epa_%NAME%")

        reftree.write(outfile=reftree_fname)

        # IMPORTANT: don't load the model, since it's invalid for the pruned tree !!!
        optmod_fname=""
        epa_result = self.raxml.run_epa(job_name, self.refalign_fname, reftree_fname, optmod_fname)
        reftree_epalbl_str = epa_result.get_std_newick_tree()
        placements = epa_result.get_placement()

        # update branchid-taxonomy mapping to account for possible changes in branch numbering
        reftree_tax = Tree(reftree_epalbl_str)
        th.set_bf_unrooted_tree(reftree_tax)
        bid_tax_map = th.get_bid_taxonomy_map()

        cl = TaxClassifyHelper(self.cfg, bid_tax_map, self.brlen_pv, self.rate, self.node_height)

        for place in placements:
            seq_name = place["n"][0]

            # get original taxonomic label
            orig_ranks = self.get_orig_ranks(seq_name)
            # get EPA tax label
            ranks, lws = cl.classify_seq(place["p"])
            # check if they match (a mismatch is appended to self.mislabels)
            self.check_seq_tax_labels(seq_name, orig_ranks, ranks, lws)

        if not self.cfg.debug:
            self.raxml.cleanup(job_name)
            FileUtils.remove_if_exists(reftree_fname)

    def run_test(self):
        """Run the whole workflow: export the reference files, run the
        leave-one-out test(s), re-check suspicious sequences with a final EPA
        run, then sort and write the results and print a summary.

        Unless `cfg.debug` is set, the exported temporary files are removed at
        the end.
        """
        self.raxml = RaxmlWrapper(self.cfg)

        print("Number of sequences in the reference: %d\n" % self.reftree_size)

        self.refjson.get_raxml_readable_tree(self.reftree_fname)
        self.refalign_fname = self.refjson.get_alignment(self.tmp_refaln)
        self.refjson.get_binary_model(self.optmod_fname)

        if self.ranktest:
            print("Running the leave-one-rank-out test...\n")
            self.run_leave_subtree_out_test()

        print("Running the leave-one-sequence-out test...\n")
        self.run_leave_seq_out_test()

        if len(self.mislabels) > 0:
            print("Leave-one-out test identified %d suspicious sequences; running final EPA test to check them...\n" % len(self.mislabels))
            self.write_mislabels(final=False)
            self.run_final_epa_test()

        self.sort_mislabels()
        self.write_mislabels()
        print("\nPercentage of mislabeled sequences: %.2f %%" % (float(len(self.mislabels)) / self.reftree_size * 100))

        if not self.cfg.debug:
            FileUtils.remove_if_exists(self.reftree_fname)
            FileUtils.remove_if_exists(self.optmod_fname)
            FileUtils.remove_if_exists(self.refalign_fname)
Ejemplo n.º 5
0
class EpaClassifier:
    def __init__(self, config, args):
        """Prepare file names and load the reference JSON for classification.

        config -- run configuration (temporary file naming, logging, settings).
        args   -- parsed command-line arguments; `jplace_fname`,
                  `ignore_refalign`, `output_name` and `output_dir` are read.

        A reference JSON that fails to parse is reported through
        `config.exit_user_error`.
        """
        self.cfg = config
        self.jplace_fname = args.jplace_fname
        self.ignore_refalign = args.ignore_refalign

        # temporary working files
        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        # the final alignment file that EPA is run on
        self.epa_alignment = config.tmp_fname("%NAME%.afa")
        self.hmmprofile = config.tmp_fname("%NAME%.hmmprofile")
        self.tmpquery = config.tmp_fname("%NAME%.tmpquery")
        self.noalign = config.tmp_fname("%NAME%.noalign")
        self.seqs = None

        # output files: <output_dir>/<output_name>.assignment.txt and .jplace
        self.out_assign_fname = os.path.join(args.output_dir,
                                             args.output_name + ".assignment.txt")
        self.out_jplace_fname = os.path.join(args.output_dir,
                                             args.output_name + ".jplace")

        try:
            self.refjson = RefJsonParser(config.refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("Invalid json file format: %s" % config.refjson_fname)

        refjson = self.refjson
        # validate input json format
        refjson.validate()
        self.reftree = refjson.get_reftree()
        self.rate = refjson.get_rate()
        self.node_height = refjson.get_node_height()
        self.cfg.compress_patterns = refjson.get_pattern_compression()

        branch_tax_map = refjson.get_branch_tax_map()
        if not branch_tax_map:
            # old file format (before 1.6), need to rebuild this map from scratch
            helper = TaxTreeHelper(self.cfg, refjson.get_origin_taxonomy())
            helper.set_mf_rooted_tree(refjson.get_tax_tree())
            helper.set_bf_unrooted_tree(refjson.get_reftree())
            branch_tax_map = helper.get_bid_taxonomy_map()
        self.bid_taxonomy_map = branch_tax_map

        taxa_count = len(self.reftree.get_leaves())
        self.cfg.log.info("Loaded reference tree with %d taxa\n" % taxa_count)

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map, self.rate, self.node_height)
        
    def require_muscle(self):
        basepath = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(basepath + "/epac/bin/muscle"):
            errmsg = "The pipeline uses MUSCLE to merge alignments, please download the programm from:\n" + \
                     "http://www.drive5.com/muscle/downloads.htm\n" + \
                     "and specify path to your installation in the config file (sativa.cfg)\n"
            self.cfg.exit_user_error(errmsg)

    def require_hmmer(self):
        basepath = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(basepath + "/epac/bin/hmmbuild") or not os.path.exists(basepath + "/epac/bin/hmmalign"):
            errmsg = "The pipeline uses HAMMER to align the query seqeunces, please download the programm from:\n" + \
                     "http://hmmer.janelia.org/\n" + \
                     "and specify path to your installation in the config file (sativa.cfg)\n"
            self.cfg.exit_user_error(errmsg)

    def align_to_refenence(self, noalign, minp = 0.9):
        """Align the query sequences to the reference alignment with HMMER.

        noalign -- passed to the aligner as its `discard` argument.
        minp    -- passed through to the aligner unchanged.

        The aligner's result is stored in `self.epa_alignment`.
        """
        ref_alignment = self.refjson.get_alignment(fout = self.tmp_refaln)
        profile = self.refjson.get_hmm_profile(self.hmmprofile)

        # if there is no hmmer profile in json file, build it from scratch
        if not profile:
            profile = hmmer(self.cfg, ref_alignment).build_hmm_profile()

        aligner = hmmer(config = self.cfg, refalign = ref_alignment, query = self.tmpquery,
                        refprofile = profile, discard = noalign, seqs = self.seqs, minp = minp)
        self.epa_alignment = aligner.align()

    def merge_alignment(self, query_seqs):
        """Write reference and query sequences into one FASTA-style file.

        The file at `self.epa_alignment` is overwritten with the reference
        alignment entries first, followed by every entry of `query_seqs`.
        """
        reference_entries = self.refjson.get_alignment_list()
        with open(self.epa_alignment, "w") as merged:
            for entry in reference_entries:
                merged.write(">" + entry[0] + "\n" + entry[1] + "\n")
            # only name and sequence are written; comment and id are ignored
            for name, seq, _comment, _sid in query_seqs.iter_entries():
                merged.write(">" + name + "\n" + seq + "\n")


    def checkinput(self, query_fname, minp = 0.9):
        """Load the query sequences and prepare the alignment file for EPA.

        The input format is guessed by trying each supported format in turn;
        if none parses, a user error is reported. Then:

        * with `ignore_refalign` set, the sequences are written as-is to
          `self.epa_alignment`, naming each one either by its matching
          reference id or with the query prefix (only the latter are counted
          in `self.query_count`);
        * otherwise the queries get the query prefix and are either merged
          with the reference alignment (directly if the lengths agree, via
          MUSCLE if not) or, when they are not aligned, aligned with HMMER.

        query_fname -- path of the query sequence file.
        minp        -- passed through to `align_to_refenence`.
        """
        formats = ["fasta", "phylip", "iphylip", "phylip_relaxed", "iphylip_relaxed"]
        for fmt in formats:
            try:
                self.seqs = SeqGroup(sequences=query_fname, format = fmt)
                break
            except Exception:
                # Parsing failed, so it is not this format. Only ordinary
                # errors are swallowed; KeyboardInterrupt/SystemExit propagate.
                self.cfg.log.debug("Guessing input format: not " + fmt)
        if self.seqs is None:
            self.cfg.exit_user_error("Invalid input file format: %s\nThe supported input formats are fasta and phylip" % query_fname)

        if self.ignore_refalign:
            self.cfg.log.info("Assuming query file contains reference sequences, skipping the alignment step...\n")
            self.query_count = 0
            # loop-invariant: fetch the reference sequence names only once
            ref_seq_names = self.refjson.get_sequences_names()
            with open(self.epa_alignment, "w") as fout:
                for name, seq, comment, sid in self.seqs.iter_entries():
                    ref_name = self.refjson.get_corr_seqid(EpacConfig.REF_SEQ_PREFIX + name)
                    if ref_name in ref_seq_names:
                        seq_name = ref_name
                    else:
                        seq_name = EpacConfig.QUERY_SEQ_PREFIX + name
                        self.query_count += 1
                    fout.write(">" + seq_name + "\n" + seq + "\n")
            return

        self.query_count = len(self.seqs)

        # add query seq name prefix to avoid confusion between reference and query sequences
        self.seqs.add_name_prefix(EpacConfig.QUERY_SEQ_PREFIX)

        self.seqs.write(format="fasta", outfile=self.tmpquery)
        self.cfg.log.info("Checking if query sequences are aligned ...")
        entries = self.seqs.get_entries()
        seql = len(entries[0][1])
        # queries count as aligned when all sequences have the same length
        aligned = all(len(entry[1]) == seql for entry in entries[1:])

        if aligned and len(self.seqs) > 1:
            self.cfg.log.info("Query sequences are aligned")
            refalnl = self.refjson.get_alignment_length()
            if refalnl == seql:
                self.cfg.log.info("Merging query alignment with reference alignment")
                self.merge_alignment(self.seqs)
            else:
                self.cfg.log.info("Merging query alignment with reference alignment using MUSCLE")
                self.require_muscle()
                refaln = self.refjson.get_alignment(fout = self.tmp_refaln)
                m = muscle(self.cfg)
                self.epa_alignment = m.merge(refaln, self.tmpquery)
        else:
            self.cfg.log.info("Query sequences are not aligned")
            self.cfg.log.info("Align query sequences to the reference alignment using HMMER")
            self.require_hmmer()
            self.align_to_refenence(self.noalign, minp = minp)

    def print_ranks(self, rks, confs, minlw = 0.0):
        uncorr_ranks = self.refjson.get_uncorr_ranks(rks)
        ss = ""
        css = ""
        for i in range(len(uncorr_ranks)):
            conf = confs[i]
            if conf == confs[0] and confs[0] >=0.99:
                conf = 1.0
            if conf >= minlw:
                ss = ss + uncorr_ranks[i] + ";"
                css = css + "{0:.3f}".format(conf) + ";"
            else:
                break
        if ss == "":
            return None
        else:
            return ss[:-1] + "\t" + css[:-1]


    def classify(self, query_fname, minp = 0.9, ptp = False):
        """Place the query sequences with EPA and write taxonomic assignments.

        Placements are read from `self.jplace_fname` when it is set;
        otherwise the input is prepared with `checkinput` and RAxML-EPA is
        run, its jplace output being moved to `self.out_jplace_fname`.

        One line per query is written to `self.out_assign_fname` (and printed
        when `cfg.verbose` is set): name, ranks, confidences and a marker --
        "*" if `novelty_check` flags the placement, "o" otherwise. Queries
        with no placement, no rank above the `cfg.min_lhw` threshold, or
        listed in the `self.noalign` file are written with a "?" marker.

        query_fname -- path of the query sequence file.
        minp        -- passed through to `checkinput`.
        ptp         -- if true, additionally run EPA-PTP species delimitation
                       and write the clusters to "<assignment file>.species".
        """
        if self.jplace_fname:
            jp = EpaJsonParser(self.jplace_fname)
        else:
            self.checkinput(query_fname, minp)

            self.cfg.log.info("Running RAxML-EPA to place %d query sequences...\n" % self.query_count)
            # use the instance's configuration rather than a module-level name
            raxml = RaxmlWrapper(self.cfg)
            reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")
            self.refjson.get_raxml_readable_tree(reftree_fname)
            optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
            self.refjson.get_binary_model(optmod_fname)
            job_name = self.cfg.subst_name("epa_%NAME%")

            reftree_str = self.refjson.get_raxml_readable_tree()
            reftree = Tree(reftree_str)

            self.reftree_size = len(reftree.get_leaves())

            # IMPORTANT: set EPA heuristic rate based on tree size!
            self.cfg.resolve_auto_settings(self.reftree_size)
            # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
            if self.cfg.epa_load_optmod:
                self.cfg.raxml_model = self.refjson.get_ratehet_model()

            reduced_align_fname = raxml.reduce_alignment(self.epa_alignment)

            jp = raxml.run_epa(job_name, reduced_align_fname, reftree_fname, optmod_fname)

            raxml.copy_epa_jplace(job_name, self.out_jplace_fname, move=True)

        self.cfg.log.info("Assigning taxonomic labels based on EPA placements...\n")

        placements = jp.get_placement()

        fo = open(self.out_assign_fname, "w") if self.out_assign_fname else None
        try:
            noassign_list = []
            for place in placements:
                taxon_name = place["n"][0]
                origin_taxon_name = EpacConfig.strip_query_prefix(taxon_name)
                edges = place["p"]
                if len(edges) == 0:
                    noassign_list.append(origin_taxon_name)
                    continue

                ranks, lws = self.classify_helper.classify_seq(edges)

                isnovo = self.novelty_check(place_edge = str(edges[0][0]), ranks=ranks, lws=lws)
                rankout = self.print_ranks(ranks, lws, self.cfg.min_lhw)

                if rankout is None:
                    noassign_list.append(origin_taxon_name)
                    continue

                output = "%s\t%s\t" % (origin_taxon_name, rankout)
                output += "*" if isnovo else "o"
                if self.cfg.verbose:
                    print(output)
                if fo:
                    fo.write(output + "\n")

            # sequences that could not be aligned are unassigned as well
            if os.path.exists(self.noalign):
                with open(self.noalign) as fnoa:
                    for line in fnoa:
                        taxon_name = line.strip()[1:]
                        noassign_list.append(EpacConfig.strip_query_prefix(taxon_name))

            # report every unassigned sequence under its own name
            for taxon_name in noassign_list:
                output = "%s\t\t\t?" % taxon_name
                if self.cfg.verbose:
                    print(output)
                if fo:
                    fo.write(output + "\n")
        finally:
            # close the assignment file even if classification fails midway
            if fo:
                fo.close()

        #############################################
        #
        # EPA-PTP species delimitation
        #
        #############################################
        if ptp:
            full_aln = SeqGroup(self.epa_alignment)
            species_list = epa_2_ptp(epa_jp = jp, ref_jp = self.refjson, full_alignment = full_aln, min_lw = 0.5, debug = self.cfg.debug)

            self.cfg.log.debug("Species clusters:")

            # the species file is named after the assignment output file
            if self.out_assign_fname:
                fo2 = open(self.out_assign_fname + ".species", "w")
            else:
                fo2 = None

            try:
                for sp_cluster in species_list:
                    translated_taxa = [EpacConfig.strip_query_prefix(taxon) for taxon in sp_cluster]
                    s = ",".join(translated_taxa)
                    if fo2:
                        fo2.write(s + "\n")
                    self.cfg.log.debug(s)
            finally:
                if fo2:
                    fo2.close()
        #############################################
        
    def novelty_check(self, place_edge, ranks, lws):
        """If the taxonomic assignment is not assigned to the genus level, 
        we need to check if it is due to the incomplete reference taxonomy or 
        it is likely to be something new:
        
        1. If the final ranks are assinged because of lw cut, that means with samller lw
        the ranks can be further assinged to lowers. This indicate the undetermined ranks 
        in the assignment is not due to the incomplete reference taxonomy, so the query 
        sequence is likely to be something new.
        
        2. Otherwise We check all leaf nodes' immediate lower rank below this ml placement point, 
        if they are not empty, output all ranks and indicate this could be novelty.
        """
        
        lowrank = 0
        for i in range(len(ranks)):
            if i < 6:
                """above genus level"""
                rk = ranks[i]
                lw = lws[i]
                if rk == "-":
                    break
                else:
                    lowrank = lowrank + 1
                    if lw >=0 and lw < self.cfg.min_lhw:
                        return True
        
        if lowrank >= 5 and lowrank < len(ranks) and not ranks[lowrank] == "-":
            return False
        else:
            placenode = self.reftree.search_nodes(B = place_edge)[0]
            if placenode.is_leaf():
                return False
            else:
                leafnodes = placenode.get_leaves()
                flag = True
                for leaf in leafnodes:
                    br_num = leaf.B
                    branks = self.bid_taxonomy_map[br_num]
                    if branks[lowrank] == "-":
                        flag = False
                        break
                        
                return flag