def export_ref_alignment(self):
    """Write the reference alignment restricted to reference-tree taxa.

    The input alignment (``self.cfg.align_fname``) is transformed as follows:
       1. Sequences which are not part of the reference tree are filtered out.
       2. Each sequence name gets the reference prefix
          (``EpacConfig.REF_SEQ_PREFIX``, e.g. ``r_``).

    The input format is guessed by trying fasta and the phylip variants in
    turn. The result is written in FASTA format to ``self.refalign_fname``,
    which is set here from ``self.cfg.tmp_fname("%NAME%_matrix.afa")``.

    Exits the process with status 1 if no supported format can parse the
    input file.
    """
    in_file = self.cfg.align_fname
    ref_seqs = None
    formats = ["fasta", "phylip", "iphylip", "phylip_relaxed", "iphylip_relaxed"]
    for fmt in formats:
        try:
            ref_seqs = SeqGroup(sequences=in_file, format=fmt)
            break
        except Exception:
            # Not parseable in this format; try the next one. (A bare
            # ``except:`` would also swallow KeyboardInterrupt/SystemExit.)
            if self.cfg.debug:
                print("Guessing input format: not " + fmt)
    if ref_seqs is None:
        print("Invalid input file format: %s" % in_file)
        print("The supported input formats are fasta and phylip")
        # Non-zero status so shells/callers can detect the failure.
        sys.exit(1)

    self.refalign_fname = self.cfg.tmp_fname("%NAME%_matrix.afa")
    with open(self.refalign_fname, "w") as fout:
        for name, seq, comment, sid in ref_seqs.iter_entries():
            seq_name = EpacConfig.REF_SEQ_PREFIX + name
            # Keep only sequences whose prefixed name is a reference-tree taxon.
            if seq_name in self.reftree_ids:
                fout.write(">" + seq_name + "\n" + seq + "\n")
Example #2
0
 def get_hmm_refalignment(self):
     """Trim the reference alignment to the columns used by the HMM profile.

     Reads the HMMER profile in ``self.refprofile`` and collects, for every
     match state, the alignment column given in its MAP annotation. Each
     sequence of ``self.refalign`` is then reduced to exactly those columns
     and written in FASTA format to ``self.trimed``.

     Assumes the HMMER3 save-file layout (header ``HMM`` line, transition
     header, COMPO block, then three lines per node) -- TODO confirm for
     other HMMER versions.

     Returns:
         tuple: (``self.trimed``, number of retained columns).
     """
     sites = []
     with open(self.refprofile) as hmp:
         # Index of MAP on a match-emission line: node number, one score per
         # alphabet symbol, then MAP. The "HMM" header lists the symbols, so
         # its token count (1 + K) is exactly that index (5 for DNA).
         map_col = 5
         in_body = False
         line = hmp.readline()
         while line != "":
             if line.startswith("//"):
                 break
             if in_body:
                 sites.append(int(line.split()[map_col]))
                 # Skip the insert-emission and state-transition lines of this node.
                 hmp.readline()
                 hmp.readline()
             elif line.startswith("HMM "):
                 in_body = True
                 map_col = len(line.split())
                 # Skip the four header/begin-state lines; the read below lands
                 # on the first node's match-emission line.
                 for _ in range(4):
                     hmp.readline()
             line = hmp.readline()

     align = SeqGroup(self.refalign)
     with open(self.trimed, "w") as fout:
         for entr in align.get_entries():
             seq = entr[1]
             fout.write(">" + entr[0] + "\n")
             # MAP positions are 1-based alignment columns.
             fout.write("".join(seq[pos - 1] for pos in sites))
             fout.write("\n")
     return self.trimed, len(sites)
Example #3
0
    def checkinput(self, query_fname, minp=0.9):
        """Load the query sequences and build the alignment used by EPA.

        The query file format is guessed (fasta, then the phylip variants).
        If ``self.ignore_refalign`` is set, the queries are assumed to already
        contain the reference sequences and ``write_combined_alignment`` is
        called directly. Otherwise query names get the query prefix and:

        * aligned queries (more than one, all of equal length) whose length
          matches the reference alignment are merged directly;
        * aligned queries of a different length are merged using MUSCLE;
        * otherwise the queries are aligned with HMMER via
          ``align_to_refenence(self.noalign, minp=minp)``.

        Calls ``self.cfg.exit_user_error`` if the file cannot be parsed or
        contains no sequences.
        """
        formats = [
            "fasta", "phylip", "iphylip", "phylip_relaxed", "iphylip_relaxed"
        ]
        # Reset so a failed parse cannot fall back to sequences left over from
        # an earlier call.
        self.seqs = None
        for fmt in formats:
            try:
                self.seqs = SeqGroup(sequences=query_fname, format=fmt)
                break
            except Exception:
                self.cfg.log.debug("Guessing input format: not " + fmt)
        if self.seqs is None:
            self.cfg.exit_user_error(
                "Invalid input file format: %s\nThe supported input formats are fasta and phylip"
                % query_fname)

        if self.ignore_refalign:
            self.cfg.log.info(
                "Assuming query file contains reference sequences, skipping the alignment step...\n"
            )
            self.write_combined_alignment()
            return

        self.query_count = len(self.seqs)

        # add query seq name prefix to avoid confusion between reference and query sequences
        self.seqs.add_name_prefix(EpacConfig.QUERY_SEQ_PREFIX)

        self.seqs.write(format="fasta", outfile=self.tmpquery)
        self.cfg.log.info("Checking if query sequences are aligned ...")
        entries = self.seqs.get_entries()
        if not entries:
            # entries[0] below would raise IndexError on an empty query file.
            self.cfg.exit_user_error("No sequences found in query file: %s" % query_fname)
            return
        seql = len(entries[0][1])
        # Treated as aligned only if every sequence has the same length.
        aligned = all(len(entri[1]) == seql for entri in entries[1:])

        if aligned and len(self.seqs) > 1:
            self.cfg.log.info("Query sequences are aligned")
            refalnl = self.refjson.get_alignment_length()
            if refalnl == seql:
                self.cfg.log.info(
                    "Merging query alignment with reference alignment")
                self.merge_alignment(self.seqs)
            else:
                self.cfg.log.info(
                    "Merging query alignment with reference alignment using MUSCLE"
                )
                self.require_muscle()
                refaln = self.refjson.get_alignment(fout=self.tmp_refaln)
                m = muscle(self.cfg)
                self.epa_alignment = m.merge(refaln, self.tmpquery)
        else:
            self.cfg.log.info("Query sequences are not aligned")
            self.cfg.log.info(
                "Align query sequences to the reference alignment using HMMER")
            self.require_hmmer()
            self.align_to_refenence(self.noalign, minp=minp)
Example #4
0
 def load_reduced_refalign(self):
     """Load ``self.reduced_refalign_fname`` into ``self.reduced_refalign_seqs``.

     Tries FASTA first, then relaxed PHYLIP. Calls
     ``self.cfg.exit_fatal_error`` if neither format can parse the file.
     """
     # Initialise first: if every format fails, the attribute would otherwise
     # never be assigned here and the None check below could raise
     # AttributeError (or silently reuse a stale value from an earlier call).
     self.reduced_refalign_seqs = None
     for fmt in ("fasta", "phylip_relaxed"):
         try:
             self.reduced_refalign_seqs = SeqGroup(
                 sequences=self.reduced_refalign_fname, format=fmt)
             break
         except Exception:
             # Not in this format; try the next one.
             continue
     if self.reduced_refalign_seqs is None:
         errmsg = "FATAL ERROR: Invalid input file format in %s! (load_reduced_refalign)" % self.reduced_refalign_fname
         self.cfg.exit_fatal_error(errmsg)
Example #5
0
 def load_alignment(self):
     """Load the input alignment ``self.cfg.align_fname`` into ``self.input_seqs``.

     The format is guessed by trying fasta, phylip_relaxed, iphylip_relaxed,
     phylip and iphylip in that order. Calls ``self.cfg.exit_user_error`` if
     none of them can parse the file.
     """
     in_file = self.cfg.align_fname
     self.input_seqs = None
     formats = [
         "fasta", "phylip_relaxed", "iphylip_relaxed", "phylip", "iphylip"
     ]
     for fmt in formats:
         try:
             self.input_seqs = SeqGroup(sequences=in_file, format=fmt)
             break
         except Exception:
             # Narrower than a bare except: lets KeyboardInterrupt/SystemExit through.
             self.cfg.log.debug("Guessing input format: not " + fmt)
     if self.input_seqs is None:
         self.cfg.exit_user_error(
             "Invalid input file format: %s\nThe supported input formats are fasta and phylip"
             % in_file)
Example #6
0
    def run_ptp(self, jp, fout=None):
        """Run EPA-PTP species delimitation on the placements in *jp*.

        Args:
            jp: EPA placement result, passed to ``epa_2_ptp``.
            fout: optional output path prefix. When given, the species
                clusters are written to ``fout + ".species"``, one cluster
                per line as comma-separated taxon names (query prefix
                stripped). Previously this referenced an undefined name.

        Every cluster is also logged at debug level.
        """
        full_aln = SeqGroup(self.epa_alignment)
        species_list = epa_2_ptp(epa_jp=jp,
                                 ref_jp=self.refjson,
                                 full_alignment=full_aln,
                                 min_lw=0.5,
                                 debug=self.cfg.debug)

        self.cfg.log.debug("Species clusters:")

        fo2 = open(fout + ".species", "w") if fout else None
        try:
            for sp_cluster in species_list:
                s = ",".join(
                    EpacConfig.strip_query_prefix(taxon) for taxon in sp_cluster)
                if fo2:
                    fo2.write(s + "\n")
                self.cfg.log.debug(s)
        finally:
            # Close the species file even if formatting/logging fails.
            if fo2:
                fo2.close()
    def checkinput(self, query_fname, minp = 0.9):
        """Load the query sequences and build the alignment used by EPA.

        The query file format is guessed (fasta, then the phylip variants).
        If ``self.ignore_refalign`` is set, sequences whose ``REF_PREFIX``-
        prefixed name is a reference sequence keep that name, all others get
        the query prefix, and everything is written to ``self.epa_alignment``.
        Otherwise query names get the query prefix and:

        * aligned queries (more than one, all of equal length) matching the
          reference alignment length are merged directly;
        * aligned queries of a different length are merged using MUSCLE;
        * otherwise the queries are aligned with HMMER.

        Exits with status 1 if the file cannot be parsed or is empty.
        """
        formats = ["fasta", "phylip", "iphylip", "phylip_relaxed", "iphylip_relaxed"]
        # Reset so a failed parse cannot reuse sequences from an earlier call.
        self.seqs = None
        for fmt in formats:
            try:
                self.seqs = SeqGroup(sequences=query_fname, format = fmt)
                break
            except Exception:
                print("Guessing input format: not " + fmt)
        if self.seqs is None:
            print("Invalid input file format!")
            print("The supported input formats are fasta and phylip")
            # Non-zero status so callers can detect the failure.
            sys.exit(1)

        if self.ignore_refalign:
            print("Assuming query file contains reference sequences, skipping the alignment step...")
            # Fetch reference names once (as a set) instead of once per sequence.
            ref_names = set(self.refjson.get_sequences_names())
            with open(self.epa_alignment, "w") as fout:
                for name, seq, comment, sid in self.seqs.iter_entries():
                    ref_name = self.REF_PREFIX + name
                    if ref_name in ref_names:
                        seq_name = ref_name
                    else:
                        seq_name = EpacConfig.QUERY_SEQ_PREFIX + name
                    fout.write(">" + seq_name + "\n" + seq + "\n")
            return

        # add query seq name prefix to avoid confusion between reference and query sequences
        self.seqs.add_name_prefix(EpacConfig.QUERY_SEQ_PREFIX)

        self.seqs.write(format="fasta", outfile=self.tmpquery)
        print("Checking if query sequences are aligned ...")
        entries = self.seqs.get_entries()
        if not entries:
            # entries[0] below would raise IndexError on an empty query file.
            print("No sequences found in query file: %s" % query_fname)
            sys.exit(1)
        seql = len(entries[0][1])
        # Treated as aligned only if every sequence has the same length.
        aligned = all(len(entri[1]) == seql for entri in entries[1:])

        if aligned and len(self.seqs) > 1:
            print("Query sequences are aligned")
            refalnl = self.refjson.get_alignment_length()
            if refalnl == seql:
                print("Merging query alignment with reference alignment")
                self.merge_alignment(self.seqs)
            else:
                print("Merging query alignment with reference alignment using MUSCLE")
                require_muscle()
                refaln = self.refjson.get_alignment(fout = self.tmp_refaln)
                m = muscle(self.cfg)
                self.epa_alignment = m.merge(refaln, self.tmpquery)
        else:
            print("Query sequences are not aligned")
            print("Align query sequences to the reference alignment using HMMER")
            require_hmmer()
            self.align_to_refenence(self.noalign, minp = minp)

        print("Running EPA ......")
        print("")
Example #8
0
def epa_2_ptp(epa_jp, ref_jp, full_alignment, min_lw = 0.5, debug = False):
    """Delimit species from EPA placements with PTP.

    Taxon names are grouped by the branch of their best placement (the first
    entry of each placement's edge list). Placements whose best likelihood
    weight is below *min_lw*, or that have no edges at all, are ignored.
    Groups with fewer than four sequences are returned as a single cluster
    each. For larger groups, the matching rows of *full_alignment* are passed
    to ``build_tree_run_ptp`` together with ``ref_jp.get_rate()``, and the
    clusters it returns are added.

    Args:
        epa_jp: placement result; ``get_placement()`` yields dicts with
            ``"p"`` (edge list, entry ``[edge, _, lw, ...]``) and ``"n"``
            (taxon names).
        ref_jp: reference package providing ``get_rate()``.
        full_alignment: alignment providing ``get_seq(name)``.
        min_lw: minimum likelihood weight for a placement to be used.
        debug: unused; kept for interface compatibility.

    Returns:
        list: species clusters, each a list of taxon names.
    """
    # Group taxon names by best-placement branch. Branch order follows the
    # first qualifying placement, and names keep their placement order.
    placemap = {}
    for placement in epa_jp.get_placement():
        edges = placement["p"]
        if not edges:
            # Unplaced query: nothing to group (edges[0] would raise IndexError).
            continue
        best_edge = edges[0]
        if best_edge[2] >= min_lw:
            placemap.setdefault(best_edge[0], []).extend(placement["n"])

    species_list = []
    # Check each placement branch.
    for seqset in placemap.values():
        if len(seqset) < 4:
            # Too few sequences to build a tree: keep the group as one cluster.
            species_list.append(seqset)
        else:
            branch_alignment = SeqGroup()
            for taxa in seqset:
                branch_alignment.set_seq(taxa, full_alignment.get_seq(taxa))
            species = build_tree_run_ptp(branch_alignment, ref_jp.get_rate())
            species_list.extend(species)
    return species_list
Example #9
0
def merge_alignment(aln1, aln2, fout, numsites):
    """Concatenate two alignments into one FASTA file after validating them.

    Args:
        aln1: first alignment file (entries written first).
        aln2: second alignment file (entries written after those of *aln1*).
        fout: path of the FASTA file to write.
        numsites: required length of every sequence (surrounding whitespace
            ignored).

    Prints a message and exits with status 1 if either alignment is empty or
    any sequence has the wrong length. Validation happens before *fout* is
    opened, so a failure no longer leaves a partially written file.
    """
    seqs1 = SeqGroup(aln1)
    seqs2 = SeqGroup(aln2)
    if len(seqs1) == 0 or len(seqs2) == 0:
        print("No sequences aligned! ")
        sys.exit(1)
    entries = list(seqs1.iter_entries()) + list(seqs2.iter_entries())
    if any(len(seq[1].strip()) != numsites for seq in entries):
        print("Error in alignment ....")
        sys.exit(1)
    with open(fout, "w") as fo:
        for seq in entries:
            fo.write(">" + seq[0] + "\n" + seq[1] + "\n")
Example #10
0
 def load_reduced_refalign(self):
     """Load ``self.reduced_refalign_fname`` into ``self.reduced_refalign_seqs``.

     Tries FASTA first, then relaxed PHYLIP. Prints a fatal error and exits
     with status 1 if neither format can parse the file.
     """
     # Initialise first: if every format fails, the attribute would otherwise
     # never be assigned here and the None check below could raise
     # AttributeError (or silently reuse a stale value from an earlier call).
     self.reduced_refalign_seqs = None
     for fmt in ("fasta", "phylip_relaxed"):
         try:
             self.reduced_refalign_seqs = SeqGroup(sequences=self.reduced_refalign_fname, format = fmt)
             break
         except Exception:
             # Not in this format; try the next one.
             continue
     if self.reduced_refalign_seqs is None:
         print("FATAL ERROR: Invalid input file format in %s! (load_reduced_refalign)" % self.reduced_refalign_fname)
         # Non-zero status so callers can detect the failure.
         sys.exit(1)
Example #11
0
 def load_reduced_refalign(self):
     """Load ``self.reduced_refalign_fname`` into ``self.reduced_refalign_seqs``.

     Tries FASTA first, then relaxed PHYLIP. Calls
     ``self.cfg.exit_fatal_error`` if neither format can parse the file.
     """
     # Initialise first: if every format fails, the attribute would otherwise
     # never be assigned here and the None check below could raise
     # AttributeError (or silently reuse a stale value from an earlier call).
     self.reduced_refalign_seqs = None
     for fmt in ("fasta", "phylip_relaxed"):
         try:
             self.reduced_refalign_seqs = SeqGroup(sequences=self.reduced_refalign_fname, format = fmt)
             break
         except Exception:
             # Not in this format; try the next one.
             continue
     if self.reduced_refalign_seqs is None:
         errmsg = "FATAL ERROR: Invalid input file format in %s! (load_reduced_refalign)" % self.reduced_refalign_fname
         self.cfg.exit_fatal_error(errmsg)
Example #12
0
 def load_alignment(self):
     """Load the input alignment ``self.cfg.align_fname`` into ``self.input_seqs``.

     The format is guessed by trying fasta, phylip_relaxed, iphylip_relaxed,
     phylip and iphylip in that order. Calls ``self.cfg.exit_user_error`` if
     none of them can parse the file.
     """
     in_file = self.cfg.align_fname
     self.input_seqs = None
     formats = ["fasta", "phylip_relaxed", "iphylip_relaxed", "phylip", "iphylip"]
     for fmt in formats:
         try:
             self.input_seqs = SeqGroup(sequences=in_file, format = fmt)
             break
         except Exception:
             # Narrower than a bare except: lets KeyboardInterrupt/SystemExit through.
             self.cfg.log.debug("Guessing input format: not " + fmt)
     if self.input_seqs is None:
         self.cfg.exit_user_error("Invalid input file format: %s\nThe supported input formats are fasta and phylip" % in_file)
Example #13
0
    def setUp(self):
        """Build an InputValidator over the bundled test fixtures.

        Loads ``testfiles/test.tax`` and ``testfiles/test.phy`` (located next
        to this module) and records the missing IDs, duplicate IDs and
        expected merge ranks that the tests check against.
        """
        config = EpacTrainerConfig()
        config.debug = True
        here = os.path.dirname(os.path.abspath(__file__))
        fixtures = os.path.join(here, "testfiles")

        taxonomy = Taxonomy(EpacConfig.REF_SEQ_PREFIX,
                            os.path.join(fixtures, "test.tax"))
        alignment = SeqGroup(sequences=os.path.join(fixtures, "test.phy"),
                             format="phylip")
        self.inval = InputValidator(config, taxonomy, alignment, False)

        self.expected_mis_ids = ["Missing1", "Missing2"]
        self.expected_dups = ["DupSeq(01)", "DupSeq02"]
        merges = []
        for dup_id in self.expected_dups:
            merges.append(self.inval.taxonomy.seq_rank_id(dup_id))
        self.expected_merges = merges
Example #14
0
 def dummy(self, reftree, alignment):
     """Run ``self.run`` on *alignment* extended by a synthetic "dummy" sequence.

     The dummy is the first sequence of the FASTA *alignment* with its last
     50 characters (or all of them, if it is shorter) replaced by "A", so it
     keeps that sequence's length. The extended alignment is written to
     ``<tmppath>/dummy<name>.fa``. ``self.run`` is called on it, then
     ``self.clean``, and the temporary file is removed afterwards (also when
     the run fails).

     Returns:
         str: ``<tmppath>/RAxML_portableTree.<name>.jplace``.
     """
     seqs = SeqGroup(sequences=alignment, format='fasta')
     seq0 = seqs.get_entries()[0][1]
     # Clamp the tail so the dummy never exceeds seq0's length
     # (seq0[:-50] + "A"*50 would be too long when len(seq0) < 50).
     tail = min(50, len(seq0))
     dummyseq = seq0[:len(seq0) - tail] + "A" * tail
     seqs.set_seq(name = "dummy", seq = dummyseq)
     fout = self.tmppath + "/dummy" + self.name + ".fa"
     seqs.write(format='fasta', outfile=fout)
     try:
         self.run(reftree = reftree, alignment = fout)
         self.clean()
     finally:
         # Do not leak the temporary alignment if the run raises.
         os.remove(fout)
     return self.tmppath + "/" + "RAxML_portableTree." + self.name + ".jplace"
Example #15
0
 def get_ref_alignment(self):
     """Return ``self.jdata["sequences"]`` as a SeqGroup.

     Each stored record contributes its first element as the sequence name
     and its second element as the sequence.
     """
     records = self.jdata["sequences"]
     alignment = SeqGroup()
     for record in records:
         seq_name, seq_data = record[0], record[1]
         alignment.set_seq(seq_name, seq_data)
     return alignment
Example #16
0
class EpaClassifier:
    """Taxonomic classifier based on RAxML-EPA placements on a reference tree.

    The reference package is read from ``config.refjson_fname`` with
    ``RefJsonParser``. Query sequences are brought into the reference
    alignment (merged directly, merged with MUSCLE, or aligned with HMMER),
    placed with RAxML-EPA, and each placement is turned into a rank
    assignment by ``TaxClassifyHelper``. EPA-PTP species delimitation can
    optionally be run on the placements.
    """

    def __init__(self, config, args):
        self.cfg = config
        self.jplace_fname = args.jplace_fname
        self.ignore_refalign = args.ignore_refalign

        # Temporary files used during a classification run.
        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        # here is the final alignment file for running EPA
        self.epa_alignment = config.tmp_fname("%NAME%.afa")
        self.hmmprofile = config.tmp_fname("%NAME%.hmmprofile")
        self.tmpquery = config.tmp_fname("%NAME%.tmpquery")
        self.noalign = config.tmp_fname("%NAME%.noalign")
        self.seqs = None

        assign_fname = args.output_name + ".assignment.txt"
        self.out_assign_fname = os.path.join(args.output_dir, assign_fname)
        jplace_fname = args.output_name + ".jplace"
        self.out_jplace_fname = os.path.join(args.output_dir, jplace_fname)

        try:
            self.refjson = RefJsonParser(config.refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("Invalid json file format: %s" % config.refjson_fname)
        # validate input json format
        self.refjson.validate()
        self.reftree = self.refjson.get_reftree()
        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        self.bid_taxonomy_map = self.refjson.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # old file format (before 1.6), need to rebuild this map from scratch
            th = TaxTreeHelper(self.cfg, self.refjson.get_origin_taxonomy())
            th.set_mf_rooted_tree(self.refjson.get_tax_tree())
            th.set_bf_unrooted_tree(self.refjson.get_reftree())
            self.bid_taxonomy_map = th.get_bid_taxonomy_map()

        self.cfg.log.info("Loaded reference tree with %d taxa\n" % len(self.reftree.get_leaves()))

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map, self.rate, self.node_height)

    def require_muscle(self):
        """Report a user error unless ``epac/bin/muscle`` exists next to this module."""
        basepath = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(basepath + "/epac/bin/muscle"):
            errmsg = "The pipeline uses MUSCLE to merge alignments, please download the program from:\n" + \
                     "http://www.drive5.com/muscle/downloads.htm\n" + \
                     "and specify path to your installation in the config file (sativa.cfg)\n"
            self.cfg.exit_user_error(errmsg)

    def require_hmmer(self):
        """Report a user error unless ``hmmbuild`` and ``hmmalign`` exist in ``epac/bin``."""
        basepath = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(basepath + "/epac/bin/hmmbuild") or not os.path.exists(basepath + "/epac/bin/hmmalign"):
            errmsg = "The pipeline uses HMMER to align the query sequences, please download the program from:\n" + \
                     "http://hmmer.janelia.org/\n" + \
                     "and specify path to your installation in the config file (sativa.cfg)\n"
            self.cfg.exit_user_error(errmsg)

    def align_to_refenence(self, noalign, minp = 0.9):
        """Align the queries in ``self.tmpquery`` to the reference alignment with HMMER.

        Uses the HMM profile stored in the reference package, or builds one
        from the reference alignment if none is stored. *noalign* and *minp*
        are passed to ``hmmer`` as ``discard`` and ``minp``. Sets
        ``self.epa_alignment`` to the value returned by ``hmmer.align()``.
        """
        refaln = self.refjson.get_alignment(fout = self.tmp_refaln)
        fprofile = self.refjson.get_hmm_profile(self.hmmprofile)

        # if there is no hmmer profile in json file, build it from scratch
        if not fprofile:
            hmm = hmmer(self.cfg, refaln)
            fprofile = hmm.build_hmm_profile()

        hm = hmmer(config = self.cfg, refalign = refaln , query = self.tmpquery, refprofile = fprofile, discard = noalign, seqs = self.seqs, minp = minp)
        self.epa_alignment = hm.align()

    def merge_alignment(self, query_seqs):
        """Write the reference alignment followed by *query_seqs* to ``self.epa_alignment`` (FASTA)."""
        refaln = self.refjson.get_alignment_list()
        with open(self.epa_alignment, "w") as fout:
            for seq in refaln:
                fout.write(">" + seq[0] + "\n" + seq[1] + "\n")
            for name, seq, comment, sid in query_seqs.iter_entries():
                fout.write(">" + name + "\n" + seq + "\n")

    def checkinput(self, query_fname, minp = 0.9):
        """Load the query sequences and build the alignment used by EPA.

        The query file format is guessed (fasta, then the phylip variants).
        If ``self.ignore_refalign`` is set, the query file is assumed to
        contain reference sequences too: names that map to a reference
        sequence (via ``get_corr_seqid``) keep that name, all others get the
        query prefix and are counted in ``self.query_count``. Otherwise all
        sequences are queries and:

        * aligned queries (more than one, all of equal length) matching the
          reference alignment length are merged directly;
        * aligned queries of a different length are merged using MUSCLE;
        * otherwise the queries are aligned with HMMER.

        Calls ``self.cfg.exit_user_error`` if the file cannot be parsed or
        contains no sequences.
        """
        formats = ["fasta", "phylip", "iphylip", "phylip_relaxed", "iphylip_relaxed"]
        # Reset so a failed parse cannot reuse sequences from an earlier call.
        self.seqs = None
        for fmt in formats:
            try:
                self.seqs = SeqGroup(sequences=query_fname, format = fmt)
                break
            except Exception:
                self.cfg.log.debug("Guessing input format: not " + fmt)
        if self.seqs is None:
            self.cfg.exit_user_error("Invalid input file format: %s\nThe supported input formats are fasta and phylip" % query_fname)

        if self.ignore_refalign:
            self.cfg.log.info("Assuming query file contains reference sequences, skipping the alignment step...\n")
            self.query_count = 0
            # Fetch the reference names once (as a set) instead of once per sequence.
            ref_names = set(self.refjson.get_sequences_names())
            with open(self.epa_alignment, "w") as fout:
                for name, seq, comment, sid in self.seqs.iter_entries():
                    ref_name = self.refjson.get_corr_seqid(EpacConfig.REF_SEQ_PREFIX + name)
                    if ref_name in ref_names:
                        seq_name = ref_name
                    else:
                        seq_name = EpacConfig.QUERY_SEQ_PREFIX + name
                        self.query_count += 1
                    fout.write(">" + seq_name + "\n" + seq + "\n")
            return

        self.query_count = len(self.seqs)

        # add query seq name prefix to avoid confusion between reference and query sequences
        self.seqs.add_name_prefix(EpacConfig.QUERY_SEQ_PREFIX)

        self.seqs.write(format="fasta", outfile=self.tmpquery)
        self.cfg.log.info("Checking if query sequences are aligned ...")
        entries = self.seqs.get_entries()
        if not entries:
            # entries[0] below would raise IndexError on an empty query file.
            self.cfg.exit_user_error("No sequences found in query file: %s" % query_fname)
            return
        seql = len(entries[0][1])
        # Treated as aligned only if every sequence has the same length.
        aligned = all(len(entri[1]) == seql for entri in entries[1:])

        if aligned and len(self.seqs) > 1:
            self.cfg.log.info("Query sequences are aligned")
            refalnl = self.refjson.get_alignment_length()
            if refalnl == seql:
                self.cfg.log.info("Merging query alignment with reference alignment")
                self.merge_alignment(self.seqs)
            else:
                self.cfg.log.info("Merging query alignment with reference alignment using MUSCLE")
                self.require_muscle()
                refaln = self.refjson.get_alignment(fout = self.tmp_refaln)
                m = muscle(self.cfg)
                self.epa_alignment = m.merge(refaln, self.tmpquery)
        else:
            self.cfg.log.info("Query sequences are not aligned")
            self.cfg.log.info("Align query sequences to the reference alignment using HMMER")
            self.require_hmmer()
            self.align_to_refenence(self.noalign, minp = minp)

    def print_ranks(self, rks, confs, minlw = 0.0):
        """Format ranks and their confidences for the assignment output.

        The ranks are translated with ``self.refjson.get_uncorr_ranks`` and
        emitted top-down until the first one whose confidence is below
        *minlw*. If the top-level confidence is at least 0.99, it (and any
        confidence equal to it) is reported as 1.0.

        Returns:
            str or None: the ";"-joined ranks, a tab, and the ";"-joined
            confidences (three decimals each); None if no rank qualifies.
        """
        uncorr_ranks = self.refjson.get_uncorr_ranks(rks)
        names = []
        scores = []
        for i, rank in enumerate(uncorr_ranks):
            conf = confs[i]
            if conf == confs[0] and confs[0] >= 0.99:
                conf = 1.0
            if conf >= minlw:
                names.append(rank)
                scores.append("{0:.3f}".format(conf))
            else:
                break
        if not names:
            return None
        return ";".join(names) + "\t" + ";".join(scores)

    def _run_epa(self, query_fname, minp):
        """Align the queries, place them with RAxML-EPA and return the placement result.

        The resulting jplace file is moved to ``self.out_jplace_fname``.
        """
        self.checkinput(query_fname, minp)

        self.cfg.log.info("Running RAxML-EPA to place %d query sequences...\n" % self.query_count)
        # Use the instance's configuration; this used to be RaxmlWrapper(config),
        # a name that is not defined in this method's scope.
        raxml = RaxmlWrapper(self.cfg)
        reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")
        self.refjson.get_raxml_readable_tree(reftree_fname)
        optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.refjson.get_binary_model(optmod_fname)
        job_name = self.cfg.subst_name("epa_%NAME%")

        reftree_str = self.refjson.get_raxml_readable_tree()
        reftree = Tree(reftree_str)

        self.reftree_size = len(reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        reduced_align_fname = raxml.reduce_alignment(self.epa_alignment)

        jp = raxml.run_epa(job_name, reduced_align_fname, reftree_fname, optmod_fname)

        raxml.copy_epa_jplace(job_name, self.out_jplace_fname, move=True)
        return jp

    def _run_ptp(self, jp):
        """Run EPA-PTP species delimitation on *jp* and report the clusters.

        Clusters are logged at debug level and, if ``self.out_assign_fname``
        is set, written one per line to ``self.out_assign_fname + ".species"``.
        """
        full_aln = SeqGroup(self.epa_alignment)
        species_list = epa_2_ptp(epa_jp = jp, ref_jp = self.refjson, full_alignment = full_aln, min_lw = 0.5, debug = self.cfg.debug)

        self.cfg.log.debug("Species clusters:")

        # The prefix used to be the undefined name `fout`; the species file
        # now sits next to the assignment output.
        fo2 = open(self.out_assign_fname + ".species", "w") if self.out_assign_fname else None
        try:
            for sp_cluster in species_list:
                s = ",".join(EpacConfig.strip_query_prefix(taxon) for taxon in sp_cluster)
                if fo2:
                    fo2.write(s + "\n")
                self.cfg.log.debug(s)
        finally:
            if fo2:
                fo2.close()

    def classify(self, query_fname, minp = 0.9, ptp = False):
        """Place the query sequences and write their taxonomic assignments.

        If ``self.jplace_fname`` is set, its placements are used directly.
        Otherwise the queries in *query_fname* are aligned (using *minp*)
        and placed with RAxML-EPA.

        For every placed query, one line is written to
        ``self.out_assign_fname`` (and printed when ``self.cfg.verbose``):
        the taxon name without its query prefix, the ranks and their
        confidences, then "*" if ``novelty_check`` flags it, else "o".
        Queries without an assignment, including those listed in
        ``self.noalign``, get a "?" line.

        If *ptp* is true, EPA-PTP species delimitation is run afterwards.
        """
        if self.jplace_fname:
            jp = EpaJsonParser(self.jplace_fname)
        else:
            jp = self._run_epa(query_fname, minp)

        self.cfg.log.info("Assigning taxonomic labels based on EPA placements...\n")

        placements = jp.get_placement()

        fo = open(self.out_assign_fname, "w") if self.out_assign_fname else None
        try:
            noassign_list = []
            for place in placements:
                taxon_name = place["n"][0]
                origin_taxon_name = EpacConfig.strip_query_prefix(taxon_name)
                edges = place["p"]
                if len(edges) > 0:
                    ranks, lws = self.classify_helper.classify_seq(edges)

                    isnovo = self.novelty_check(place_edge = str(edges[0][0]), ranks=ranks, lws=lws)
                    rankout = self.print_ranks(ranks, lws, self.cfg.min_lhw)

                    if rankout is None:
                        noassign_list.append(origin_taxon_name)
                    else:
                        output = "%s\t%s\t" % (origin_taxon_name, rankout)
                        output += "*" if isnovo else "o"
                        if self.cfg.verbose:
                            print(output)
                        if fo:
                            fo.write(output + "\n")
                else:
                    noassign_list.append(origin_taxon_name)

            # Queries listed in the HMMER discard file (passed as `discard`
            # in align_to_refenence) are reported as unassigned too.
            if os.path.exists(self.noalign):
                with open(self.noalign) as fnoa:
                    for line in fnoa:
                        taxon_name = line.strip()[1:]
                        noassign_list.append(EpacConfig.strip_query_prefix(taxon_name))

            for taxon_name in noassign_list:
                # Report each unassigned taxon (this used to repeat the last
                # placed query's name for every line).
                output = "%s\t\t\t?" % taxon_name
                if self.cfg.verbose:
                    print(output)
                if fo:
                    fo.write(output + "\n")
        finally:
            # Close the assignment file even if classification fails midway.
            if fo:
                fo.close()

        #############################################
        #
        # EPA-PTP species delimitation
        #
        #############################################
        if ptp:
            self._run_ptp(jp)

    def novelty_check(self, place_edge, ranks, lws):
        """Decide whether an assignment that stops above genus level looks novel.

        If the taxonomic assignment does not reach the genus level, this
        checks whether that is due to an incomplete reference taxonomy or
        whether the query is likely to be something new:

        1. If the assignment stopped because a likelihood-weight cut-off was
           hit (some rank among the first six has 0 <= lw < min_lhw), a
           smaller cut-off would have assigned lower ranks. The missing ranks
           are therefore not caused by the reference taxonomy, and the query
           is reported as likely novel.

        2. Otherwise, every leaf below the placement branch is checked at the
           first unassigned rank. If all of them have a rank there, the query
           is reported as likely novel.

        Returns:
            bool: True if the query is likely novel.
        """
        lowrank = 0
        # Only the first six ranks (above genus level) are examined.
        for i, rk in enumerate(ranks[:6]):
            lw = lws[i]
            if rk == "-":
                break
            lowrank += 1
            if lw >= 0 and lw < self.cfg.min_lhw:
                return True

        if lowrank >= 5 and lowrank < len(ranks) and ranks[lowrank] != "-":
            return False

        placenode = self.reftree.search_nodes(B = place_edge)[0]
        if placenode.is_leaf():
            return False
        for leaf in placenode.get_leaves():
            # A leaf lacking the rank means the reference taxonomy itself is
            # incomplete there, so the query is not flagged as novel.
            if self.bid_taxonomy_map[leaf.B][lowrank] == "-":
                return False
        return True
class EpaClassifier:
    """Taxonomically classify query sequences against an EPA reference JSON file.

    Workflow (see :meth:`classify`):
    1. Load the query sequences and align them to the reference alignment
       (direct merge, MUSCLE or HMMER).
    2. Place them on the reference tree with RAxML-EPA, or read placements
       from an existing jplace file.
    3. Translate the placements into rank assignments.
    """

    def __init__(self, config, args):
        """Load the reference JSON file and prepare temporary file names.

        :param config: EPAC configuration object (provides ``tmp_fname``,
            ``refjson_fname`` and classification settings).
        :param args: parsed command-line arguments; ``jplace_fname``,
            ``ignore_refalign`` and ``p_value`` are read here.
        """
        self.cfg = config
        self.jplace_fname = args.jplace_fname
        self.ignore_refalign = args.ignore_refalign

        self.tmp_refaln = config.tmp_fname("%NAME%.refaln")
        # here is the final alignment file for running EPA
        self.epa_alignment = config.tmp_fname("%NAME%.afa")
        self.hmmprofile = config.tmp_fname("%NAME%.hmmprofile")
        self.tmpquery = config.tmp_fname("%NAME%.tmpquery")
        self.noalign = config.tmp_fname("%NAME%.noalign")
        self.seqs = None

        try:
            self.refjson = RefJsonParser(config.refjson_fname)
        except ValueError:
            print("Invalid json file format!")
            sys.exit()
        # validate input json format
        self.refjson.validate()
        self.bid_taxonomy_map = self.refjson.get_bid_tanomomy_map()
        self.reftree = self.refjson.get_reftree()
        self.rate = self.refjson.get_rate()
        self.node_height = self.refjson.get_node_height()
        self.cfg.compress_patterns = self.refjson.get_pattern_compression()

        self.classify_helper = TaxClassifyHelper(self.cfg, self.bid_taxonomy_map, args.p_value, self.rate, self.node_height)

    def cleanup(self):
        """Remove the temporary files created by this classifier, if present."""
        for fname in (self.tmp_refaln, self.epa_alignment, self.hmmprofile, self.tmpquery, self.noalign):
            FileUtils.remove_if_exists(fname)

    def align_to_refenence(self, noalign, minp = 0.9):
        """Align the query sequences (``self.tmpquery``) to the reference alignment with HMMER.

        The HMMER profile is taken from the reference JSON file. If the JSON
        file has none, it is built from the reference alignment.
        ``self.epa_alignment`` is replaced by the file name HMMER returns.

        :param noalign: file name passed to HMMER as ``discard`` (presumably
            receives the queries that could not be aligned; :meth:`classify`
            reports the entries of this file as unassigned).
        :param minp: passed through to the HMMER wrapper.
        """
        refaln = self.refjson.get_alignment(fout = self.tmp_refaln)
        fprofile = self.refjson.get_hmm_profile(self.hmmprofile)

        # if there is no hmmer profile in json file, build it from scratch
        if not fprofile:
            hmm = hmmer(self.cfg, refaln)
            fprofile = hmm.build_hmm_profile()

        hm = hmmer(config = self.cfg, refalign = refaln , query = self.tmpquery, refprofile = fprofile, discard = noalign, seqs = self.seqs, minp = minp)
        self.epa_alignment = hm.align()

    def merge_alignment(self, query_seqs):
        """Write reference and (already aligned) query sequences into ``self.epa_alignment`` as FASTA."""
        refaln = self.refjson.get_alignment_list()
        with open(self.epa_alignment, "w") as fout:
            for seq in refaln:
                fout.write(">" + seq[0] + "\n" + seq[1] + "\n")
            for name, seq, comment, sid in query_seqs.iter_entries():
                fout.write(">" + name + "\n" + seq + "\n")

    def checkinput(self, query_fname, minp = 0.9):
        """Load the query sequences and produce the alignment used for EPA.

        The input format is guessed by trying fasta and the phylip variants in
        turn; the process exits if none of them parses. Then:

        - with ``ignore_refalign`` the file is assumed to already contain the
          reference sequences and is written out directly;
        - aligned queries of the reference width are merged as-is;
        - aligned queries of a different width are merged with MUSCLE;
        - unaligned queries are aligned with HMMER.

        :param query_fname: query sequence file.
        :param minp: passed through to the HMMER alignment step.
        """
        self.seqs = None
        formats = ["fasta", "phylip", "iphylip", "phylip_relaxed", "iphylip_relaxed"]
        for fmt in formats:
            try:
                self.seqs = SeqGroup(sequences=query_fname, format = fmt)
                break
            except Exception:
                print("Guessing input format: not " + fmt)
        if self.seqs is None:
            print("Invalid input file format!")
            print("The supported input formats are fasta and phylip")
            sys.exit()

        if self.ignore_refalign:
            print("Assuming query file contains reference sequences, skipping the alignment step...")
            # query the reference names once instead of once per sequence
            ref_names = set(self.refjson.get_sequences_names())
            with open(self.epa_alignment, "w") as fout:
                for name, seq, comment, sid in self.seqs.iter_entries():
                    ref_name = EpacConfig.REF_SEQ_PREFIX + name
                    if ref_name in ref_names:
                        seq_name = ref_name
                    else:
                        seq_name = EpacConfig.QUERY_SEQ_PREFIX + name
                    fout.write(">" + seq_name + "\n" + seq + "\n")
            return

        # add query seq name prefix to avoid confusion between reference and query sequences
        self.seqs.add_name_prefix(EpacConfig.QUERY_SEQ_PREFIX)

        self.seqs.write(format="fasta", outfile=self.tmpquery)
        print("Checking if query sequences are aligned ...")
        entries = self.seqs.get_entries()
        if not entries:
            print("Query file contains no sequences: %s" % query_fname)
            sys.exit()
        seql = len(entries[0][1])
        aligned = all(len(entry[1]) == seql for entry in entries[1:])

        if aligned and len(self.seqs) > 1:
            print("Query sequences are aligned")
            refalnl = self.refjson.get_alignment_length()
            if refalnl == seql:
                print("Merging query alignment with reference alignment")
                self.merge_alignment(self.seqs)
            else:
                print("Merging query alignment with reference alignment using MUSCLE")
                require_muscle()
                refaln = self.refjson.get_alignment(fout = self.tmp_refaln)
                m = muscle(self.cfg)
                self.epa_alignment = m.merge(refaln, self.tmpquery)
        else:
            print("Query sequences are not aligned")
            print("Align query sequences to the reference alignment using HMMER")
            require_hmmer()
            self.align_to_refenence(self.noalign, minp = minp)

        print("Running EPA ......")
        print("")

    def print_ranks(self, rks, confs, minlw = 0.0):
        """Format the ranks whose weight reaches ``minlw``.

        Ranks are taken in order until the first one whose weight is below
        ``minlw``. A weight equal to the first weight is shown as 1.0 when the
        first weight is at least 0.99.

        :return: ``"rank1;rank2\\tw1;w2"`` with weights printed to 3 decimals,
            or None if no rank qualifies.
        """
        names = []
        scores = []
        for rank, conf in zip(rks, confs):
            if conf == confs[0] and confs[0] >= 0.99:
                conf = 1.0
            if conf < minlw:
                break
            names.append(rank)
            scores.append("{0:.3f}".format(conf))
        if not names:
            return None
        return ";".join(names) + "\t" + ";".join(scores)

    def classify(self, query_fname, fout = None, method = "1", minlw = 0.0, pv = 0.02, minp = 0.9, ptp = False):
        """Place query sequences and write their taxonomic assignments.

        Output line format: ``name<TAB>ranks<TAB>weights<TAB>flag``.
        ``flag`` is ``*`` (possibly novel, see :meth:`novelty_check`) or
        ``o``. Unassigned queries are written as ``name<TAB><TAB><TAB>?``.

        :param query_fname: query file; not read when a jplace file was given.
        :param fout: output file name; when None, results are only printed in
            verbose mode.
        :param method: classification method for ``TaxClassifyHelper.classify_seq``.
        :param minlw: minimum likelihood weight for a rank to be reported.
        :param pv: currently unused (the Erlang filter is disabled).
        :param minp: passed through to the HMMER alignment step.
        :param ptp: also run EPA-PTP species delimitation; clusters are
            written to ``fout + ".species"`` when ``fout`` is given.
        """
        if self.jplace_fname:
            jp = EpaJsonParser(self.jplace_fname)
        else:
            self.checkinput(query_fname, minp)
            raxml = RaxmlWrapper(self.cfg)
            reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")
            self.refjson.get_raxml_readable_tree(reftree_fname)
            optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
            self.refjson.get_binary_model(optmod_fname)
            job_name = self.cfg.subst_name("epa_%NAME%")

            reftree_str = self.refjson.get_raxml_readable_tree()
            reftree = Tree(reftree_str)

            self.reftree_size = len(reftree.get_leaves())

            # IMPORTANT: set EPA heuristic rate based on tree size!
            self.cfg.resolve_auto_settings(self.reftree_size)
            # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
            if self.cfg.epa_load_optmod:
                self.cfg.raxml_model = self.refjson.get_ratehet_model()

            reduced_align_fname = raxml.reduce_alignment(self.epa_alignment)

            jp = raxml.run_epa(job_name, reduced_align_fname, reftree_fname, optmod_fname)

        placements = jp.get_placement()

        fo = open(fout, "w") if fout else None
        try:
            # unassigned queries are collected and written after the assigned ones
            unassigned = []
            for place in placements:
                taxon_name = place["n"][0]
                origin_taxon_name = EpacConfig.strip_query_prefix(taxon_name)
                edges = place["p"]
                rankout = None
                if len(edges) > 0:
                    ranks, lws = self.classify_helper.classify_seq(edges, method, minlw)
                    rankout = self.print_ranks(ranks, lws, minlw)

                if rankout is None:
                    unassigned.append(origin_taxon_name + "\t\t\t?\n")
                    continue

                isnovo = self.novelty_check(place_edge = str(edges[0][0]), ranks = ranks, lws = lws, minlw = minlw)
                output = "%s\t%s\t%s" % (origin_taxon_name, rankout, "*" if isnovo else "o")
                if self.cfg.verbose:
                    print(output)
                if fo:
                    fo.write(output + "\n")

            # sequences which could not be aligned are reported as unassigned
            if os.path.exists(self.noalign):
                with open(self.noalign) as fnoa:
                    for line in fnoa:
                        taxon_name = line.strip()[1:]
                        origin_taxon_name = EpacConfig.strip_query_prefix(taxon_name)
                        output = "%s\t\t\t?" % origin_taxon_name
                        if self.cfg.verbose:
                            print(output)
                        if fo:
                            fo.write(output + "\n")

            output2 = "".join(unassigned)
            if self.cfg.verbose:
                print(output2)
            if fo:
                fo.write(output2)
        finally:
            if fo:
                fo.close()

        #############################################
        #
        # EPA-PTP species delimitation
        #
        #############################################
        if ptp:
            full_aln = SeqGroup(self.epa_alignment)
            species_list = epa_2_ptp(epa_jp = jp, ref_jp = self.refjson, full_alignment = full_aln, min_lw = 0.5, debug = self.cfg.debug)

            if self.cfg.verbose:
                print("Species clusters:")

            fo2 = open(fout + ".species", "w") if fout else None
            try:
                for sp_cluster in species_list:
                    s = ",".join(EpacConfig.strip_query_prefix(taxon) for taxon in sp_cluster)
                    if fo2:
                        fo2.write(s + "\n")
                    if self.cfg.verbose:
                        print(s)
            finally:
                if fo2:
                    fo2.close()
        #############################################

        if not self.jplace_fname and not self.cfg.debug:
            raxml.cleanup(job_name)
            FileUtils.remove_if_exists(reduced_align_fname)
            FileUtils.remove_if_exists(reftree_fname)
            FileUtils.remove_if_exists(optmod_fname)

    def novelty_check(self, place_edge, ranks, lws, minlw):
        """Decide whether a query whose assignment stops early may be novel.

        1. If an assigned rank (only the first six rank levels are inspected)
           has a non-negative likelihood weight below ``minlw``, a smaller
           cutoff would have assigned further ranks. The undetermined ranks
           are then not caused by an incomplete reference taxonomy, so the
           query is likely something new.
        2. Otherwise all leaves below the placement edge are checked at the
           first unassigned rank level. If none of them is empty there, the
           query could be novel.

        :param place_edge: EPA branch label of the placement (matched against
            node attribute ``B`` of ``self.reftree``).
        :param ranks: assigned rank names; ``"-"`` marks an unassigned rank.
        :param lws: likelihood weights, parallel to ``ranks``.
        :param minlw: likelihood-weight cutoff used for the assignment.
        :return: True if the query is flagged as potentially novel, else False.
        """
        lowrank = 0
        for i in range(min(len(ranks), 6)):
            rk = ranks[i]
            lw = lws[i]
            if rk == "-":
                break
            lowrank += 1
            if 0 <= lw < minlw:
                return True

        # bounds check: when all ranks are assigned there is no entry at index lowrank
        if lowrank >= 5 and lowrank < len(ranks) and ranks[lowrank] != "-":
            return False

        placenode = self.reftree.search_nodes(B = place_edge)[0]
        if placenode.is_leaf():
            return False
        for leaf in placenode.get_leaves():
            if self.bid_taxonomy_map[leaf.B][lowrank] == "-":
                return False
        return True
Example #18
0
class RefTreeBuilder:
    """Build an EPA reference JSON file from a taxonomy and a reference alignment.

    :meth:`build_ref_tree` is the top-level entry point. The other methods are
    its individual pipeline steps; each relies on attributes set by the
    preceding steps.
    """
    def __init__(self, config): 
        self.cfg = config
        self.mfresolv_job_name = self.cfg.subst_name("mfresolv_%NAME%")
        self.epalbl_job_name = self.cfg.subst_name("epalbl_%NAME%")
        self.optmod_job_name = self.cfg.subst_name("optmod_%NAME%")
        self.raxml_wrapper = RaxmlWrapper(config)

        self.outgr_fname = self.cfg.tmp_fname("%NAME%_outgr.tre")
        self.reftree_mfu_fname = self.cfg.tmp_fname("%NAME%_mfu.tre")
        self.reftree_bfu_fname = self.cfg.tmp_fname("%NAME%_bfu.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.lblalign_fname = self.cfg.tmp_fname("%NAME%_lblq.fa")
        self.reftree_lbl_fname = self.cfg.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = self.cfg.tmp_fname("%NAME%_tax.tre")
        self.brmap_fname = self.cfg.tmp_fname("%NAME%_map.txt")

    def load_alignment(self):
        """Load ``cfg.align_fname`` into ``self.input_seqs``, trying each supported format in turn."""
        in_file = self.cfg.align_fname
        self.input_seqs = None
        formats = ["fasta", "phylip_relaxed", "iphylip_relaxed", "phylip", "iphylip"]
        for fmt in formats:
            try:
                self.input_seqs = SeqGroup(sequences=in_file, format = fmt)
                break
            except Exception:
                self.cfg.log.debug("Guessing input format: not " + fmt)
        if self.input_seqs is None:
            self.cfg.exit_user_error("Invalid input file format: %s\nThe supported input formats are fasta and phylip" % in_file)

    def validate_taxonomy(self):
        """Cross-check taxonomy and alignment; the validator is kept for its correction maps."""
        self.input_validator = InputValidator(self.cfg, self.taxonomy, self.input_seqs)
        self.input_validator.validate()

    def build_multif_tree(self):
        """Build the multifurcating tree from the taxonomy and record its sequence ids."""
        c = self.cfg

        tb = TaxTreeBuilder(c, self.taxonomy)
        (t, ids) = tb.build(c.reftree_min_rank, c.reftree_max_seqs_per_leaf, c.reftree_clades_to_include, c.reftree_clades_to_ignore)
        self.reftree_ids = frozenset(ids)
        self.reftree_size = len(ids)
        self.reftree_multif = t

        # IMPORTANT: select GAMMA or CAT model based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)

        if self.cfg.debug:
            refseq_fname = self.cfg.tmp_fname("%NAME%_seq_ids.txt")
            # list of sequence ids which comprise the reference tree
            with open(refseq_fname, "w") as f:
                for sid in ids:
                    f.write("%s\n" % sid)

            # original tree with taxonomic ranks as internal node labels
            reftax_fname = self.cfg.tmp_fname("%NAME%_mfu_tax.tre")
            t.write(outfile=reftax_fname, format=8)

    def export_ref_alignment(self):
        """This function transforms the input alignment in the following way:
           1. Filter out sequences which are not part of the reference tree
           2. Add sequence name prefix (r_)
           Sequence ids corrected by the input validator are renamed accordingly."""

        self.refalign_fname = self.cfg.tmp_fname("%NAME%_matrix.afa")
        with open(self.refalign_fname, "w") as fout:
            for name, seq, comment, sid in self.input_seqs.iter_entries():
                seq_name = EpacConfig.REF_SEQ_PREFIX + name
                if seq_name in self.input_validator.corr_seqid:
                    seq_name = self.input_validator.corr_seqid[seq_name]
                if seq_name in self.reftree_ids:
                    fout.write(">" + seq_name + "\n" + seq + "\n")

        # we do not need the original alignment anymore, so free its memory
        self.input_seqs = None

    def export_ref_taxonomy(self):
        """Restrict the taxonomy to the sequences in the reference tree (``self.taxonomy_map``)."""
        self.taxonomy_map = {sid: ranks for sid, ranks in self.taxonomy.iteritems() if sid in self.reftree_ids}

        if self.cfg.debug:
            tax_fname = self.cfg.tmp_fname("%NAME%_tax.txt")
            with open(tax_fname, "w") as fout:
                for sid in self.taxonomy_map:
                    ranks_str = self.taxonomy.seq_lineage_str(sid)
                    fout.write(sid + "\t" + ranks_str + "\n")

    def save_rooting(self):
        """Save the outgroup used for later re-rooting and write the unrooted multifurcating tree."""
        rt = self.reftree_multif

        tax_map = self.taxonomy.get_map()
        self.taxtree_helper = TaxTreeHelper(self.cfg, tax_map)
        self.taxtree_helper.set_mf_rooted_tree(rt)
        outgr = self.taxtree_helper.get_outgroup()
        outgr_size = len(outgr.get_leaves())
        outgr.write(outfile=self.outgr_fname, format=9)
        self.reftree_outgroup = outgr
        self.cfg.log.debug("Outgroup for rooting was saved to: %s, outgroup size: %d", self.outgr_fname, outgr_size)

        # remove unifurcation at the root
        if len(rt.children) == 1:
            rt = rt.children[0]

        # now we can safely unroot the tree and remove internal node labels to make it suitable for raxml
        rt.write(outfile=self.reftree_mfu_fname, format=9)

    # RAxML call to convert multifurcating tree to the strictly bifurcating one
    def resolve_multif(self):
        """Resolve the multifurcating tree with constrained RAxML runs and (optionally) re-optimize the model.

        Produces ``self.reftree_bfu_fname``, ``self.optmod_fname`` and
        ``self.reftree_loglh``; exits via ``cfg.exit_fatal_error`` on RAxML failure.
        """
        self.cfg.log.debug("\nReducing the alignment: \n")
        self.reduced_refalign_fname = self.raxml_wrapper.reduce_alignment(self.refalign_fname)

        self.cfg.log.debug("\nConstrained ML inference: \n")
        raxml_params = ["-s", self.reduced_refalign_fname, "-g", self.reftree_mfu_fname, "--no-seq-check", "-N", str(self.cfg.rep_num)]
        if self.cfg.mfresolv_method == "fast":
            raxml_params += ["-D"]
        elif self.cfg.mfresolv_method == "ultrafast":
            raxml_params += ["-f", "e"]
        if self.cfg.restart and self.raxml_wrapper.result_exists(self.mfresolv_job_name):
            self.invocation_raxml_multif = self.raxml_wrapper.get_invocation_str(self.mfresolv_job_name)
            self.cfg.log.debug("\nUsing existing ML tree found in: %s\n", self.raxml_wrapper.result_fname(self.mfresolv_job_name))
        else:
            self.invocation_raxml_multif = self.raxml_wrapper.run(self.mfresolv_job_name, raxml_params)
            if self.cfg.mfresolv_method == "ultrafast":
                # "-f e" writes a result tree rather than a best tree; copy it to where the best tree is expected
                self.raxml_wrapper.copy_result_tree(self.mfresolv_job_name, self.raxml_wrapper.besttree_fname(self.mfresolv_job_name))

        if not self.raxml_wrapper.besttree_exists(self.mfresolv_job_name):
            errmsg = "RAxML run failed (mutlifurcation resolution), please examine the log for details: %s" \
                    % self.raxml_wrapper.make_raxml_fname("output", self.mfresolv_job_name)
            self.cfg.exit_fatal_error(errmsg)
            return

        if not self.cfg.reopt_model:
            self.raxml_wrapper.copy_best_tree(self.mfresolv_job_name, self.reftree_bfu_fname)
            self.raxml_wrapper.copy_optmod_params(self.mfresolv_job_name, self.optmod_fname)
            self.invocation_raxml_optmod = ""
            job_name = self.mfresolv_job_name
        else:
            bfu_fname = self.raxml_wrapper.besttree_fname(self.mfresolv_job_name)
            job_name = self.optmod_job_name

            # RAxML call to optimize model parameters and write them down to the binary model file
            self.cfg.log.debug("\nOptimizing model parameters: \n")
            raxml_params = ["-f", "e", "-s", self.reduced_refalign_fname, "-t", bfu_fname, "--no-seq-check"]
            if self.cfg.raxml_model.startswith("GTRCAT") and not self.cfg.compress_patterns:
                raxml_params += ["-H"]
            if self.cfg.restart and self.raxml_wrapper.result_exists(self.optmod_job_name):
                self.invocation_raxml_optmod = self.raxml_wrapper.get_invocation_str(self.optmod_job_name)
                self.cfg.log.debug("\nUsing existing optimized tree and parameters found in: %s\n", self.raxml_wrapper.result_fname(self.optmod_job_name))
            else:
                self.invocation_raxml_optmod = self.raxml_wrapper.run(self.optmod_job_name, raxml_params)
            if self.raxml_wrapper.result_exists(self.optmod_job_name):
                self.raxml_wrapper.copy_result_tree(self.optmod_job_name, self.reftree_bfu_fname)
                self.raxml_wrapper.copy_optmod_params(self.optmod_job_name, self.optmod_fname)
            else:
                errmsg = "RAxML run failed (model optimization), please examine the log for details: %s" \
                        % self.raxml_wrapper.make_raxml_fname("output", self.optmod_job_name)
                self.cfg.exit_fatal_error(errmsg)

        mod_name = "CAT" if self.cfg.raxml_model.startswith("GTRCAT") else "GAMMA"
        self.reftree_loglh = self.raxml_wrapper.get_tree_lh(job_name, mod_name)
        self.cfg.log.debug("\n%s-based logLH of the reference tree: %f\n", mod_name, self.reftree_loglh)

    def load_reduced_refalign(self):
        """Load the RAxML-reduced reference alignment (fasta or relaxed phylip) into ``self.reduced_refalign_seqs``."""
        # must be initialised: if every format fails we report the error instead of raising AttributeError
        self.reduced_refalign_seqs = None
        formats = ["fasta", "phylip_relaxed"]
        for fmt in formats:
            try:
                self.reduced_refalign_seqs = SeqGroup(sequences=self.reduced_refalign_fname, format = fmt)
                break
            except Exception:
                pass
        if self.reduced_refalign_seqs is None:
            errmsg = "FATAL ERROR: Invalid input file format in %s! (load_reduced_refalign)" % self.reduced_refalign_fname
            self.cfg.exit_fatal_error(errmsg)

    # dummy EPA run to label the branches of the reference tree, which we need to build a mapping to tax ranks
    def epa_branch_labeling(self):
        """Run EPA with a single dummy query to obtain branch labels for the reference tree."""
        # create alignment with dummy query seq
        self.refalign_width = len(self.reduced_refalign_seqs.get_seqbyid(0))
        self.reduced_refalign_seqs.write(format="fasta", outfile=self.lblalign_fname)

        with open(self.lblalign_fname, "a") as fout:
            fout.write(">" + "DUMMY131313" + "\n")
            fout.write("A"*self.refalign_width + "\n")

        # TODO always load model regardless of the config file settings?
        epa_result = self.raxml_wrapper.run_epa(self.epalbl_job_name, self.lblalign_fname, self.reftree_bfu_fname, self.optmod_fname, mode="epa_mp")

        # check for failure before reading anything from the EPA result
        if not self.raxml_wrapper.epa_result_exists(self.epalbl_job_name):
            errmsg = "RAxML EPA run failed, please examine the log for details: %s" \
                    % self.raxml_wrapper.make_raxml_fname("output", self.epalbl_job_name)
            self.cfg.exit_fatal_error(errmsg)
            return

        self.reftree_lbl_str = epa_result.get_std_newick_tree()
        self.raxml_version = epa_result.get_raxml_version()
        self.invocation_raxml_epalbl = epa_result.get_raxml_invocation()

    def epa_post_process(self):
        """Re-root the EPA-labelled tree and derive the taxonomic tree and branch-to-ranks map."""
        lbl_tree = Tree(self.reftree_lbl_str)
        self.taxtree_helper.set_bf_unrooted_tree(lbl_tree)
        self.reftree_tax = self.taxtree_helper.get_tax_tree()
        self.bid_ranks_map = self.taxtree_helper.get_bid_taxonomy_map()

        if self.cfg.debug:
            self.reftree_tax.write(outfile=self.reftree_tax_fname, format=3)
            with open(self.reftree_lbl_fname, "w") as outf:
                outf.write(self.reftree_lbl_str)
            with open(self.brmap_fname, "w") as outf:
                for bid, br_rec in self.bid_ranks_map.iteritems():
                    outf.write("%s\t%s\t%d\t%f\n" % (bid, br_rec[0], br_rec[1], br_rec[2]))

    def calc_node_heights(self):
        """Calculate node heights on the reference tree (used to define branch-length cutoff during classification step)
           Algorithm is as follows:
           Tip node or node resolved to Species level: height = 1 
           Inner node resolved to Genus or above:      height = min(left_height, right_height) + 1 
         """
        nh_map = {}
        dummy_added = False
        for node in self.reftree_tax.traverse("postorder"):
            if node.is_root():
                continue
            if not hasattr(node, "B"):
                # In a rooted tree, there is always one more node/branch than in unrooted one
                # That's why one branch will be always not EPA-labelled after the rooting
                if not dummy_added:
                    node.B = "DDD"
                    dummy_added = True
                    species_rank = Taxonomy.EMPTY_RANK
                else:
                    errmsg = "FATAL ERROR: More than one tree branch without EPA label (calc_node_heights)"
                    self.cfg.exit_fatal_error(errmsg)
            else:
                species_rank = self.bid_ranks_map[node.B][-1]
            bid = node.B
            if node.is_leaf() or species_rank != Taxonomy.EMPTY_RANK:
                nh_map[bid] = 1
            else:
                lchild = node.children[0]
                rchild = node.children[1]
                nh_map[bid] = min(nh_map[lchild.B], nh_map[rchild.B]) + 1

        # remove heights for dummy nodes, since there won't be any placements on them
        if dummy_added:
            del nh_map["DDD"]

        self.node_height_map = nh_map

    def __get_all_rank_names(self, root):
        """Return the set of all rank names found on ``root`` and its descendants."""
        rnames = set()
        for node in root.traverse("postorder"):
            rnames.update(node.ranks)
        return rnames

    def mono_index(self):
        """This method will calculate monophyly index by looking at the left and right hand side of the tree.

        :return: rank names shared by the two subtrees below the root, or an
            empty set if the root is not bifurcating.
        """
        children = self.reftree_tax.children
        # descend through unifurcations at the root
        while len(children) == 1:
            children = children[0].children
        if len(children) == 2:
            lset = self.__get_all_rank_names(children[0])
            rset = self.__get_all_rank_names(children[1])
            return lset & rset
        print("Error: input tree not bifurcating")
        return set()

    def build_hmm_profile(self, json_builder):
        """Build a HMMER profile from the reduced reference alignment and store it in ``json_builder``."""
        print("Building the HMMER profile...\n")

        # this stupid workaround is needed because RAxML outputs the reduced
        # alignment in relaxed PHYLIP format, which is not supported by HMMER
        refalign_fasta = self.cfg.tmp_fname("%NAME%_ref_reduced.fa")
        self.reduced_refalign_seqs.write(outfile=refalign_fasta)

        hmm = hmmer(self.cfg, refalign_fasta)
        fprofile = hmm.build_hmm_profile()

        json_builder.set_hmm_profile(fprofile)

    def write_json(self):
        """Assemble all collected results and write the reference JSON file (``cfg.refjson_fname``)."""
        jw = RefJsonBuilder()

        jw.set_branch_tax_map(self.bid_ranks_map)
        jw.set_tree(self.reftree_lbl_str)
        jw.set_outgroup(self.reftree_outgroup)
        jw.set_ratehet_model(self.cfg.raxml_model)
        jw.set_tax_tree(self.reftree_multif)
        jw.set_pattern_compression(self.cfg.compress_patterns)
        jw.set_taxcode(self.cfg.taxcode_name)

        jw.set_merged_ranks_map(self.input_validator.merged_ranks)
        # the JSON stores the correction maps inverted (corrected -> original)
        corr_ranks_reverse = {v: k for k, v in self.input_validator.corr_ranks.items()}
        jw.set_corr_ranks_map(corr_ranks_reverse)
        corr_seqid_reverse = {v: k for k, v in self.input_validator.corr_seqid.items()}
        jw.set_corr_seqid_map(corr_seqid_reverse)

        mdata = { "ref_tree_size": self.reftree_size, 
                  "ref_alignment_width": self.refalign_width,
                  "raxml_version": self.raxml_version,
                  "timestamp": str(datetime.datetime.now()),
                  "invocation_epac": self.invocation_epac,
                  "invocation_raxml_multif": self.invocation_raxml_multif,
                  "invocation_raxml_optmod": self.invocation_raxml_optmod,
                  "invocation_raxml_epalbl": self.invocation_raxml_epalbl,
                  "reftree_loglh": self.reftree_loglh
                }
        jw.set_metadata(mdata)

        seqs = self.reduced_refalign_seqs.get_entries()
        jw.set_sequences(seqs)

        if not self.cfg.no_hmmer:
            self.build_hmm_profile(jw)

        orig_tax = self.taxonomy_map
        jw.set_origin_taxonomy(orig_tax)

        self.cfg.log.debug("Calculating the speciation rate...\n")
        tp = tree_param(tree = self.reftree_lbl_str, origin_taxonomy = orig_tax)
        jw.set_rate(tp.get_speciation_rate_fast())
        jw.set_nodes_height(self.node_height_map)

        jw.set_binary_model(self.optmod_fname)

        self.cfg.log.debug("Writing down the reference file...\n")
        jw.dump(self.cfg.refjson_fname)

    # top-level function to build a reference tree
    def build_ref_tree(self):
        """Run the complete reference-building pipeline and write the reference JSON file."""
        self.cfg.log.info("=> Loading taxonomy from file: %s ...\n" , self.cfg.taxonomy_fname)
        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX, tax_fname=self.cfg.taxonomy_fname)
        self.cfg.log.info("==> Loading reference alignment from file: %s ...\n" , self.cfg.align_fname)
        self.load_alignment()
        self.cfg.log.info("===> Validating taxonomy and alignment ...\n")
        self.validate_taxonomy()
        self.cfg.log.info("====> Building a multifurcating tree from taxonomy with %d seqs ...\n" , self.taxonomy.seq_count())
        self.build_multif_tree()
        self.cfg.log.info("=====> Building the reference alignment ...\n")
        self.export_ref_alignment()
        self.export_ref_taxonomy()
        self.cfg.log.info("======> Saving the outgroup for later re-rooting ...\n")
        self.save_rooting()
        self.cfg.log.info("=======> Resolving multifurcation: choosing the best topology from %d independent RAxML runs ...\n", self.cfg.rep_num)
        self.resolve_multif()
        self.load_reduced_refalign()
        self.cfg.log.info("========> Calling RAxML-EPA to obtain branch labels ...\n")
        self.epa_branch_labeling()
        self.cfg.log.info("=========> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n")
        self.epa_post_process()
        self.calc_node_heights()

        self.cfg.log.debug("\n==========> Checking branch labels ...")
        self.cfg.log.debug("shared rank names before training: %s", repr(self.taxonomy.get_common_ranks()))
        self.cfg.log.debug("shared rank names after  training: %s\n", repr(self.mono_index()))

        self.cfg.log.info("==========> Saving the reference JSON file: %s\n", self.cfg.refjson_fname)
        self.write_json()
Example #19
0
class RefTreeBuilder:
    """Build a reference tree (plus its JSON description) for EPA-based classification.

    Inputs come from the config object: a taxonomy file (``cfg.taxonomy_fname``)
    and an alignment (``cfg.align_fname``). ``build_ref_tree`` runs the whole
    pipeline; the other methods are its individual steps and pass results to
    each other through instance attributes.

    Fatal errors are reported on stdout and terminate the process with exit
    status 1.
    """

    def __init__(self, config):
        self.cfg = config
        # RAxML job names (the %NAME% placeholder is expanded via cfg.subst_name)
        self.mfresolv_job_name = self.cfg.subst_name("mfresolv_%NAME%")
        self.epalbl_job_name = self.cfg.subst_name("epalbl_%NAME%")
        self.optmod_job_name = self.cfg.subst_name("optmod_%NAME%")
        self.raxml_wrapper = RaxmlWrapper(config)

        # intermediate files (paths produced by cfg.tmp_fname)
        self.outgr_fname = self.cfg.tmp_fname("%NAME%_outgr.tre")
        self.reftree_mfu_fname = self.cfg.tmp_fname("%NAME%_mfu.tre")
        self.reftree_bfu_fname = self.cfg.tmp_fname("%NAME%_bfu.tre")
        self.optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.lblalign_fname = self.cfg.tmp_fname("%NAME%_lblq.fa")
        self.reftree_lbl_fname = self.cfg.tmp_fname("%NAME%_lbl.tre")
        self.reftree_tax_fname = self.cfg.tmp_fname("%NAME%_tax.tre")
        self.brmap_fname = self.cfg.tmp_fname("%NAME%_map.txt")

    def validate_taxonomy(self):
        """Check the taxonomy for rank irregularities and duplicate rank names.

        The reaction to each problem class is configured by
        ``cfg.wrong_rank_count`` / ``cfg.dup_rank_names``:
        "ignore" (no check), "autofix" (fix and report), "skip" (remove the
        offending sequences) or anything else (report and exit with status 1).
        Afterwards rank names are normalised and taxonomy gaps are closed.
        """
        # make sure we don't have taxonomy "irregularities" (more than 7 ranks or missing ranks in the middle)
        action = self.cfg.wrong_rank_count
        if action != "ignore":
            autofix = action == "autofix"
            errs = self.taxonomy.check_for_disbalance(autofix)
            if errs:
                if action == "autofix":
                    print("WARNING: %d sequences with invalid annotation (missing/redundant ranks) found and were fixed as follows:\n" % len(errs))
                    for err in errs:
                        print("Original:   %s\t%s"   % (err[0], err[1]))
                        print("Fixed as:   %s\t%s\n" % (err[0], err[2]))
                elif action == "skip":
                    print("WARNING: Following %d sequences with invalid annotation (missing/redundant ranks) were skipped:\n" % len(errs))
                    for err in errs:
                        self.taxonomy.remove_seq(err[0])
                        # index explicitly: "%s\t%s" % err breaks if err is a list
                        # or carries more than two fields
                        print("%s\t%s" % (err[0], err[1]))
                else:  # abort
                    print("ERROR: %d sequences with invalid annotation (missing/redundant ranks) found:\n" % len(errs))
                    for err in errs:
                        print("%s\t%s" % (err[0], err[1]))
                    print("\nPlease fix them manually (add/remove ranks) and run the pipeline again (or use -wrong-rank-count autofix option)")
                    print("NOTE: Only standard 7-level taxonomies are supported at the moment. Although missing trailing ranks (e.g. species) are allowed,")
                    print("missing intermediate ranks (e.g. family) or sublevels (e.g. suborder) are not!\n")
                    sys.exit(1)

        # check for duplicate rank names
        action = self.cfg.dup_rank_names
        if action != "ignore":
            autofix = action == "autofix"
            dups = self.taxonomy.check_for_duplicates(autofix)
            if dups:
                if action == "autofix":
                    print("WARNING: %d sequences with duplicate rank names found and were renamed as follows:\n" % len(dups))
                    for dup in dups:
                        print("Original:    %s\t%s"   %  (dup[0], dup[1]))
                        print("Duplicate:   %s\t%s"   %  (dup[2], dup[3]))
                        print("Renamed to:  %s\t%s\n" %  (dup[2], dup[4]))
                elif action == "skip":
                    print("WARNING: Following %d sequences with duplicate rank names were skipped:\n" % len(dups))
                    for dup in dups:
                        self.taxonomy.remove_seq(dup[2])
                        print("%s\t%s\n" % (dup[2], dup[3]))
                else:  # abort
                    print("ERROR: %d sequences with duplicate rank names found:\n" % len(dups))
                    for dup in dups:
                        print("%s\t%s\n%s\t%s\n" % (dup[0], dup[1], dup[2], dup[3]))
                    print("Please fix (rename) them and run the pipeline again (or use -dup-rank-names autofix option)")
                    sys.exit(1)

        # check for invalid characters in rank names
        self.taxonomy.normalize_rank_names()

        self.taxonomy.close_taxonomy_gaps()

    def build_multif_tree(self):
        """Build the multifurcating taxonomic tree and record which sequences it contains.

        Sets ``reftree_ids`` (frozenset of sequence ids), ``reftree_size`` and
        ``reftree_multif``, then lets the config resolve settings that depend on
        the tree size. In debug mode the ids and the rank-labelled tree are
        written to temp files.
        """
        c = self.cfg

        tb = TaxTreeBuilder(c, self.taxonomy)
        (t, ids) = tb.build(c.reftree_min_rank, c.reftree_max_seqs_per_leaf, c.reftree_clades_to_include, c.reftree_clades_to_ignore)
        self.reftree_ids = frozenset(ids)
        self.reftree_size = len(ids)
        self.reftree_multif = t

        # IMPORTANT: select GAMMA or CAT model based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)

        if self.cfg.debug:
            refseq_fname = self.cfg.tmp_fname("%NAME%_seq_ids.txt")
            # list of sequence ids which comprise the reference tree
            with open(refseq_fname, "w") as f:
                for sid in ids:
                    f.write("%s\n" % sid)

            # original tree with taxonomic ranks as internal node labels
            reftax_fname = self.cfg.tmp_fname("%NAME%_mfu_tax.tre")
            t.write(outfile=reftax_fname, format=8)

    def export_ref_alignment(self):
        """Write the reference alignment (FASTA) to ``self.refalign_fname``.

        The input alignment (``cfg.align_fname``) is transformed as follows:
           1. Filter out sequences which are not part of the reference tree
           2. Add sequence name prefix (EpacConfig.REF_SEQ_PREFIX, e.g. r_)
        The input format is guessed by trying FASTA and several PHYLIP variants;
        if none parses, the process exits with status 1.
        """
        in_file = self.cfg.align_fname
        ref_seqs = None
        formats = ["fasta", "phylip", "iphylip", "phylip_relaxed", "iphylip_relaxed"]
        for fmt in formats:
            try:
                ref_seqs = SeqGroup(sequences=in_file, format=fmt)
                break
            except Exception:
                # not this format - try the next one (a bare "except:" would
                # also swallow KeyboardInterrupt / SystemExit)
                if self.cfg.debug:
                    print("Guessing input format: not " + fmt)
        if ref_seqs is None:
            print("Invalid input file format: %s" % in_file)
            print("The supported input formats are fasta and phylip")
            sys.exit(1)

        self.refalign_fname = self.cfg.tmp_fname("%NAME%_matrix.afa")
        with open(self.refalign_fname, "w") as fout:
            for name, seq, comment, sid in ref_seqs.iter_entries():
                seq_name = EpacConfig.REF_SEQ_PREFIX + name
                if seq_name in self.reftree_ids:
                    fout.write(">" + seq_name + "\n" + seq + "\n")

    def export_ref_taxonomy(self):
        """Restrict the taxonomy to the reference-tree sequences (``self.taxonomy_map``).

        In debug mode the resulting lineages are also dumped to a temp file.
        """
        self.taxonomy_map = {}
        for sid, ranks in self.taxonomy.iteritems():
            if sid in self.reftree_ids:
                self.taxonomy_map[sid] = ranks

        if self.cfg.debug:
            tax_fname = self.cfg.tmp_fname("%NAME%_tax.txt")
            with open(tax_fname, "w") as fout:
                for sid in self.taxonomy_map:
                    ranks_str = self.taxonomy.lineage_str(sid, True)
                    fout.write(sid + "\t" + ranks_str + "\n")

    def save_rooting(self):
        """Save the outgroup used for later re-rooting and write the tree for RAxML.

        The outgroup is written to ``outgr_fname`` and kept in
        ``reftree_outgroup``; the multifurcating tree (without a root
        unifurcation, without internal labels) goes to ``reftree_mfu_fname``.
        """
        rt = self.reftree_multif

        tax_map = self.taxonomy.get_map()
        self.taxtree_helper = TaxTreeHelper(tax_map, self.cfg)
        self.taxtree_helper.set_mf_rooted_tree(rt)
        outgr = self.taxtree_helper.get_outgroup()
        outgr_size = len(outgr.get_leaves())
        outgr.write(outfile=self.outgr_fname, format=9)
        self.reftree_outgroup = outgr
        if self.cfg.verbose:
            print("Outgroup for rooting was saved to: %s, outgroup size: %d" % (self.outgr_fname, outgr_size))

        # remove unifurcation at the root
        if len(rt.children) == 1:
            rt = rt.children[0]

        # now we can safely unroot the tree and remove internal node labels to make it suitable for raxml
        rt.write(outfile=self.reftree_mfu_fname, format=9)

    def resolve_multif(self):
        """Convert the multifurcating tree into a strictly bifurcating one with RAxML.

        Reduces the reference alignment, runs a constrained RAxML search (-g)
        on ``reftree_mfu_fname``, then a "-f e" run on the resulting tree to
        optimise model parameters. The final tree and the binary model file are
        copied to ``reftree_bfu_fname`` / ``optmod_fname``. Exits with status 1
        if either RAxML run produced no result.
        """
        print("\nReducing the alignment: \n")
        self.reduced_refalign_fname = self.raxml_wrapper.reduce_alignment(self.refalign_fname)

        print("\nResolving multifurcation: \n")
        raxml_params = ["-s", self.reduced_refalign_fname, "-g", self.reftree_mfu_fname, "-F", "--no-seq-check"]
        if self.cfg.mfresolv_method == "fast":
            raxml_params += ["-D"]
        elif self.cfg.mfresolv_method == "ultrafast":
            raxml_params += ["-f", "e"]
        self.invocation_raxml_multif = self.raxml_wrapper.run(self.mfresolv_job_name, raxml_params)
        if not self.raxml_wrapper.result_exists(self.mfresolv_job_name):
            print("RAxML run failed (mutlifurcation resolution), please examine the log for details: %s"
                  % self.raxml_wrapper.make_raxml_fname("output", self.mfresolv_job_name))
            sys.exit(1)

        bfu_fname = self.raxml_wrapper.result_fname(self.mfresolv_job_name)

        # RAxML call to optimize model parameters and write them down to the binary model file
        print("\nOptimizing model parameters: \n")
        raxml_params = ["-f", "e", "-s", self.reduced_refalign_fname, "-t", bfu_fname, "--no-seq-check"]
        if self.cfg.raxml_model == "GTRCAT" and not self.cfg.compress_patterns:
            raxml_params += ["-H"]
        self.invocation_raxml_optmod = self.raxml_wrapper.run(self.optmod_job_name, raxml_params)
        if not self.raxml_wrapper.result_exists(self.optmod_job_name):
            print("RAxML run failed (model optimization), please examine the log for details: %s"
                  % self.raxml_wrapper.make_raxml_fname("output", self.optmod_job_name))
            sys.exit(1)

        self.raxml_wrapper.copy_result_tree(self.optmod_job_name, self.reftree_bfu_fname)
        self.raxml_wrapper.copy_optmod_params(self.optmod_job_name, self.optmod_fname)
        if not self.cfg.debug:
            self.raxml_wrapper.cleanup(self.optmod_job_name)
            self.raxml_wrapper.cleanup(self.mfresolv_job_name)

    def load_reduced_refalign(self):
        """Load the RAxML-reduced alignment into ``self.reduced_refalign_seqs``.

        FASTA and relaxed PHYLIP are tried in turn; if neither parses the file,
        an error is printed and the process exits with status 1.
        """
        # Reset first: without this the None check below raises AttributeError
        # (instead of reporting the problem) when no format could be parsed.
        self.reduced_refalign_seqs = None
        for fmt in ["fasta", "phylip_relaxed"]:
            try:
                self.reduced_refalign_seqs = SeqGroup(sequences=self.reduced_refalign_fname, format=fmt)
                break
            except Exception:
                pass
        if self.reduced_refalign_seqs is None:
            print("FATAL ERROR: Invalid input file format in %s! (load_reduced_refalign)" % self.reduced_refalign_fname)
            sys.exit(1)

    def epa_branch_labeling(self):
        """Run a dummy EPA placement to label the branches of the reference tree.

        The labels are needed to build a mapping from branches to taxonomic
        ranks. One all-"A" query (DUMMY131313) with the alignment width is
        appended to the reduced alignment. Sets ``refalign_width``,
        ``reftree_lbl_str``, ``raxml_version`` and ``invocation_raxml_epalbl``;
        exits with status 1 if the EPA run produced no result.
        """
        # create alignment with dummy query seq
        self.refalign_width = len(self.reduced_refalign_seqs.get_seqbyid(0))
        self.reduced_refalign_seqs.write(format="fasta", outfile=self.lblalign_fname)

        with open(self.lblalign_fname, "a") as fout:
            fout.write(">" + "DUMMY131313" + "\n")
            fout.write("A" * self.refalign_width + "\n")

        epa_result = self.raxml_wrapper.run_epa(self.epalbl_job_name, self.lblalign_fname, self.reftree_bfu_fname, self.optmod_fname)
        self.reftree_lbl_str = epa_result.get_std_newick_tree()
        self.raxml_version = epa_result.get_raxml_version()
        self.invocation_raxml_epalbl = epa_result.get_raxml_invocation()

        if not self.raxml_wrapper.epa_result_exists(self.epalbl_job_name):
            print("RAxML EPA run failed, please examine the log for details: %s"
                  % self.raxml_wrapper.make_raxml_fname("output", self.epalbl_job_name))
            sys.exit(1)
        if not self.cfg.debug:
            self.raxml_wrapper.cleanup(self.epalbl_job_name, True)

    def epa_post_process(self):
        """Derive the taxonomy-annotated tree and the branch-id -> ranks map from the EPA tree."""
        lbl_tree = Tree(self.reftree_lbl_str)
        self.taxtree_helper.set_bf_unrooted_tree(lbl_tree)
        self.reftree_tax = self.taxtree_helper.get_tax_tree()
        self.bid_ranks_map = self.taxtree_helper.get_bid_taxonomy_map()

        if self.cfg.debug:
            self.reftree_tax.write(outfile=self.reftree_lbl_fname, format=5)
            self.reftree_tax.write(outfile=self.reftree_tax_fname, format=3)

    def build_branch_rank_map(self):
        """Map each EPA branch label (node.B) to the ranks of the branch's parent node.

        The root has no incoming branch and is skipped silently; in verbose mode
        other nodes without a label are reported.
        """
        self.bid_ranks_map = {}
        for node in self.reftree_tax.traverse("postorder"):
            if node.is_root():
                # the root never carries a branch label - nothing to report
                continue
            if hasattr(node, "B"):
                self.bid_ranks_map[node.B] = node.up.ranks
            elif self.cfg.verbose:
                print("INFO: EPA branch label missing, mapping to taxon skipped (%s)" % node.name)

    def write_branch_rank_map(self):
        """Write "<branch label>\\t<rank1;rank2;...>" lines to ``brmap_fname``."""
        with open(self.brmap_fname, "w") as fbrmap:
            for node in self.reftree_tax.traverse("postorder"):
                if not node.is_root() and hasattr(node, "B"):
                    fbrmap.write(node.B + "\t" + ";".join(self.bid_ranks_map[node.B]) + "\n")

    def calc_node_heights(self):
        """Calculate node heights on the reference tree (used to define branch-length cutoff during classification step)
           Algorithm is as follows:
           Tip node or node resolved to Species level: height = 1
           Inner node resolved to Genus or above:      height = min(left_height, right_height) + 1

           The map branch label -> height is stored in ``self.node_height_map``.
         """
        nh_map = {}
        dummy_added = False
        for node in self.reftree_tax.traverse("postorder"):
            if node.is_root():
                continue
            if not hasattr(node, "B"):
                # In a rooted tree, there is always one more node/branch than in unrooted one
                # That's why one branch will be always not EPA-labelled after the rooting
                if not dummy_added:
                    node.B = "DDD"
                    dummy_added = True
                    species_rank = Taxonomy.EMPTY_RANK
                else:
                    print("FATAL ERROR: More than one tree branch without EPA label (calc_node_heights)")
                    sys.exit(1)
            else:
                # index 6 = species level of the standard 7-level rank list
                species_rank = self.bid_ranks_map[node.B][6]
            bid = node.B
            if node.is_leaf() or species_rank != Taxonomy.EMPTY_RANK:
                nh_map[bid] = 1
            else:
                # postorder traversal guarantees both children were processed already
                lchild = node.children[0]
                rchild = node.children[1]
                nh_map[bid] = min(nh_map[lchild.B], nh_map[rchild.B]) + 1

        # remove heights for dummy nodes, since there won't be any placements on them
        if dummy_added:
            del nh_map["DDD"]

        self.node_height_map = nh_map

    def __get_all_rank_names(self, root):
        """Return the set of all rank names found on the nodes of the subtree *root*."""
        rnames = set()
        for node in root.traverse("postorder"):
            rnames.update(node.ranks)
        return rnames

    def mono_index(self):
        """Return the rank names shared by the left and right subtree of the root.

        Unifurcating nodes at the top of the tree are skipped first. If the first
        non-unary node is not binary, an error is printed and an empty set is
        returned.
        """
        children = self.reftree_tax.children
        while len(children) == 1:
            children = children[0].children
        if len(children) != 2:
            print("Error: input tree not birfurcating")
            return set()
        left, right = children
        return self.__get_all_rank_names(left) & self.__get_all_rank_names(right)

    def build_hmm_profile(self, json_builder):
        """Build a HMMER profile from the reduced alignment and attach it to *json_builder*."""
        print("Building the HMMER profile...\n")

        # workaround: RAxML outputs the reduced alignment in relaxed PHYLIP
        # format, which is not supported by HMMER
        refalign_fasta = self.cfg.tmp_fname("%NAME%_ref_reduced.fa")
        self.reduced_refalign_seqs.write(outfile=refalign_fasta)

        hmm = hmmer(self.cfg, refalign_fasta)
        fprofile = hmm.build_hmm_profile()

        json_builder.set_hmm_profile(fprofile)

        if not self.cfg.debug:
            FileUtils.remove_if_exists(refalign_fasta)
            FileUtils.remove_if_exists(fprofile)

    def write_json(self):
        """Assemble tree, taxonomy, model and metadata and dump them to ``cfg.refjson_fname``."""
        jw = RefJsonBuilder()

        jw.set_taxonomy(self.bid_ranks_map)
        jw.set_tree(self.reftree_lbl_str)
        jw.set_outgroup(self.reftree_outgroup)
        jw.set_ratehet_model(self.cfg.raxml_model)
        jw.set_tax_tree(self.reftree_multif)
        jw.set_pattern_compression(self.cfg.compress_patterns)

        mdata = {"ref_tree_size": self.reftree_size,
                 "ref_alignment_width": self.refalign_width,
                 "raxml_version": self.raxml_version,
                 "timestamp": str(datetime.datetime.now()),
                 "invocation_epac": self.invocation_epac,
                 "invocation_raxml_multif": self.invocation_raxml_multif,
                 "invocation_raxml_optmod": self.invocation_raxml_optmod,
                 "invocation_raxml_epalbl": self.invocation_raxml_epalbl
                 }
        jw.set_metadata(mdata)

        seqs = self.reduced_refalign_seqs.get_entries()
        jw.set_sequences(seqs)

        if not self.cfg.no_hmmer:
            self.build_hmm_profile(jw)

        orig_tax = self.taxonomy_map
        jw.set_origin_taxonomy(orig_tax)

        print("Calculating the speciation rate...\n")
        tp = tree_param(tree=self.reftree_lbl_str, origin_taxonomy=orig_tax)
        jw.set_rate(tp.get_speciation_rate_fast())
        jw.set_nodes_height(self.node_height_map)

        jw.set_binary_model(self.optmod_fname)

        print("Writing down the reference file...\n")
        jw.dump(self.cfg.refjson_fname)

    def cleanup(self):
        """Remove the intermediate files produced by the pipeline."""
        FileUtils.remove_if_exists(self.outgr_fname)
        FileUtils.remove_if_exists(self.reftree_mfu_fname)
        FileUtils.remove_if_exists(self.reftree_bfu_fname)
        FileUtils.remove_if_exists(self.optmod_fname)
        FileUtils.remove_if_exists(self.lblalign_fname)
        FileUtils.remove_if_exists(self.reduced_refalign_fname)
        FileUtils.remove_if_exists(self.refalign_fname)

    # top-level function to build a reference tree
    def build_ref_tree(self):
        """Run the whole pipeline, reporting progress and total runtime on stdout."""
        start_time = time.time()
        print("\n> Loading taxonomy from file: %s ...\n" % (self.cfg.taxonomy_fname))
        self.taxonomy = GGTaxonomyFile(self.cfg.taxonomy_fname, EpacConfig.REF_SEQ_PREFIX)
        print("\n=> Building a multifurcating tree from taxonomy with %d seqs ...\n" % self.taxonomy.seq_count())
        self.validate_taxonomy()
        self.build_multif_tree()
        print("\n==> Building the reference alignment ...\n")
        self.export_ref_alignment()
        self.export_ref_taxonomy()
        print("\n===> Saving the outgroup for later re-rooting ...\n")
        self.save_rooting()
        print("\n====> RAxML call: resolve multifurcation ...\n")
        self.resolve_multif()
        self.load_reduced_refalign()
        print("\n=====> RAxML-EPA call: labeling the branches ...\n")
        self.epa_branch_labeling()
        print("\n======> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n")
        self.epa_post_process()
        self.calc_node_heights()

        if self.cfg.verbose:
            print("\n=======> Checking branch labels ...\n")
            print("shared rank names before training: " + repr(self.taxonomy.get_common_ranks()))
            print("shared rank names after  training: " + repr(self.mono_index()))

        print("\n=======> Saving the reference JSON file ...\n")
        self.write_json()
        elapsed_time = time.time() - start_time
        print("\n***********  Done! (%.0f s) **********\n" % elapsed_time)
Example #20
0
class RefTreeBuilder:
    def __init__(self, config):
        """Keep the config and derive the RAxML job names and temp-file paths from it."""
        self.cfg = config

        # RAxML job names, "%NAME%" expanded by the config
        job_name = config.subst_name
        self.mfresolv_job_name = job_name("mfresolv_%NAME%")
        self.epalbl_job_name = job_name("epalbl_%NAME%")
        self.optmod_job_name = job_name("optmod_%NAME%")

        self.raxml_wrapper = RaxmlWrapper(config)

        # intermediate files
        tmp_path = config.tmp_fname
        self.outgr_fname = tmp_path("%NAME%_outgr.tre")
        self.reftree_mfu_fname = tmp_path("%NAME%_mfu.tre")
        self.reftree_bfu_fname = tmp_path("%NAME%_bfu.tre")
        self.optmod_fname = tmp_path("%NAME%.opt")
        self.lblalign_fname = tmp_path("%NAME%_lblq.fa")
        self.reftree_lbl_fname = tmp_path("%NAME%_lbl.tre")
        self.reftree_tax_fname = tmp_path("%NAME%_tax.tre")
        self.brmap_fname = tmp_path("%NAME%_map.txt")

    def load_alignment(self):
        in_file = self.cfg.align_fname
        self.input_seqs = None
        formats = [
            "fasta", "phylip_relaxed", "iphylip_relaxed", "phylip", "iphylip"
        ]
        for fmt in formats:
            try:
                self.input_seqs = SeqGroup(sequences=in_file, format=fmt)
                break
            except:
                self.cfg.log.debug("Guessing input format: not " + fmt)
        if self.input_seqs == None:
            self.cfg.exit_user_error(
                "Invalid input file format: %s\nThe supported input formats are fasta and phylip"
                % in_file)

    def validate_taxonomy(self):
        """Validate the taxonomy together with the input alignment via InputValidator.

        The validator is kept in ``self.input_validator``; its ``corr_seqid``
        mapping is consulted later by export_ref_alignment.
        """
        validator = InputValidator(self.cfg, self.taxonomy, self.input_seqs)
        self.input_validator = validator
        validator.validate()

    def build_multif_tree(self):
        """Build the multifurcating taxonomic tree for the reference sequences.

        Sets ``reftree_ids`` (frozenset), ``reftree_size`` and
        ``reftree_multif``, and lets the config resolve its size-dependent
        settings. In debug mode the sequence ids and the rank-labelled tree are
        dumped to temp files.
        """
        cfg = self.cfg
        builder = TaxTreeBuilder(cfg, self.taxonomy)
        tree, seq_ids = builder.build(cfg.reftree_min_rank,
                                      cfg.reftree_max_seqs_per_leaf,
                                      cfg.reftree_clades_to_include,
                                      cfg.reftree_clades_to_ignore)
        self.reftree_ids = frozenset(seq_ids)
        self.reftree_size = len(seq_ids)
        self.reftree_multif = tree

        # IMPORTANT: the GAMMA vs. CAT model choice depends on the tree size
        cfg.resolve_auto_settings(self.reftree_size)

        if not cfg.debug:
            return

        # ids of all sequences that make up the reference tree, one per line
        ids_path = cfg.tmp_fname("%NAME%_seq_ids.txt")
        with open(ids_path, "w") as ids_file:
            ids_file.writelines("%s\n" % seq_id for seq_id in seq_ids)

        # original tree with taxonomic ranks as internal node labels
        tree.write(outfile=cfg.tmp_fname("%NAME%_mfu_tax.tre"), format=8)

    def export_ref_alignment(self):
        """Write the reference alignment (FASTA) to ``self.refalign_fname``.

        Each input sequence name gets the reference prefix (r_), then the id
        correction recorded by the input validator (if any) is applied; only
        sequences that are part of the reference tree are written. The parsed
        input alignment is released afterwards.
        """
        self.refalign_fname = self.cfg.tmp_fname("%NAME%_matrix.afa")
        prefix = EpacConfig.REF_SEQ_PREFIX
        corrected_ids = self.input_validator.corr_seqid
        with open(self.refalign_fname, "w") as out:
            for name, seq, _comment, _sid in self.input_seqs.iter_entries():
                ref_name = prefix + name
                if ref_name in corrected_ids:
                    ref_name = corrected_ids[ref_name]
                if ref_name not in self.reftree_ids:
                    continue
                out.write("".join([">", ref_name, "\n", seq, "\n"]))

        # the original alignment is not needed anymore - free its memory
        self.input_seqs = None

    def export_ref_taxonomy(self):
        self.taxonomy_map = {}

        for sid, ranks in self.taxonomy.iteritems():
            if sid in self.reftree_ids:
                self.taxonomy_map[sid] = ranks

        if self.cfg.debug:
            tax_fname = self.cfg.tmp_fname("%NAME%_tax.txt")
            with open(tax_fname, "w") as fout:
                for sid, ranks in self.taxonomy_map.iteritems():
                    ranks_str = self.taxonomy.seq_lineage_str(sid)
                    fout.write(sid + "\t" + ranks_str + "\n")

    def save_rooting(self):
        """Save the rooting outgroup and write the tree that RAxML will resolve.

        The outgroup chosen by TaxTreeHelper is written to ``outgr_fname`` and
        kept in ``reftree_outgroup``. The multifurcating tree, minus a possible
        unifurcation at the root, is written to ``reftree_mfu_fname``.
        """
        mf_tree = self.reftree_multif

        helper = TaxTreeHelper(self.cfg, self.taxonomy.get_map())
        self.taxtree_helper = helper
        helper.set_mf_rooted_tree(mf_tree)

        outgroup = helper.get_outgroup()
        n_outgroup_leaves = len(outgroup.get_leaves())
        outgroup.write(outfile=self.outgr_fname, format=9)
        self.reftree_outgroup = outgroup
        self.cfg.log.debug(
            "Outgroup for rooting was saved to: %s, outgroup size: %d",
            self.outgr_fname, n_outgroup_leaves)

        # drop a unifurcation at the root; then the tree can be written unrooted
        # and without internal node labels, which makes it suitable for raxml
        if len(mf_tree.children) == 1:
            mf_tree = mf_tree.children[0]
        mf_tree.write(outfile=self.reftree_mfu_fname, format=9)

    # RAxML call to convert multifurcating tree to the strictly bifurcating one
    def resolve_multif(self):
        """Turn the multifurcating reference tree into a bifurcating one with RAxML.

        Steps:
          1. Reduce the reference alignment (``raxml_wrapper.reduce_alignment``).
          2. Constrained ML inference on ``reftree_mfu_fname`` (``-g``), passing
             ``cfg.rep_num`` as ``-N``. With ``cfg.restart`` an existing result
             of this job is reused instead of re-running RAxML.
          3. If ``cfg.reopt_model`` is off, the best tree and model parameters of
             step 2 are copied as-is; otherwise a separate ``-f e`` run
             re-optimises the model parameters on the best tree (again reusing
             an existing result on restart).
          4. The log-likelihood of the final tree is stored in ``reftree_loglh``.

        Sets ``reduced_refalign_fname``, ``invocation_raxml_multif``,
        ``invocation_raxml_optmod`` and ``reftree_loglh``, and copies the tree and
        model parameters to ``reftree_bfu_fname`` / ``optmod_fname``. Calls
        ``cfg.exit_fatal_error`` if the expected RAxML output is missing.
        """
        self.cfg.log.debug("\nReducing the alignment: \n")
        self.reduced_refalign_fname = self.raxml_wrapper.reduce_alignment(
            self.refalign_fname)

        self.cfg.log.debug("\nConstrained ML inference: \n")
        raxml_params = [
            "-s", self.reduced_refalign_fname, "-g", self.reftree_mfu_fname,
            "--no-seq-check", "-N",
            str(self.cfg.rep_num)
        ]
        # extra RAxML flags depending on the configured resolution method
        if self.cfg.mfresolv_method == "fast":
            raxml_params += ["-D"]
        elif self.cfg.mfresolv_method == "ultrafast":
            raxml_params += ["-f", "e"]
        # on restart, reuse the output of a previous run of this job if present
        if self.cfg.restart and self.raxml_wrapper.result_exists(
                self.mfresolv_job_name):
            self.invocation_raxml_multif = self.raxml_wrapper.get_invocation_str(
                self.mfresolv_job_name)
            self.cfg.log.debug(
                "\nUsing existing ML tree found in: %s\n",
                self.raxml_wrapper.result_fname(self.mfresolv_job_name))
        else:
            self.invocation_raxml_multif = self.raxml_wrapper.run(
                self.mfresolv_job_name, raxml_params)
            #            self.invocation_raxml_multif = self.raxml_wrapper.run_multiple(self.mfresolv_job_name, raxml_params, self.cfg.rep_num)
            # NOTE(review): presumably the "-f e" run writes no best-tree file,
            # so its result tree is copied under that name for the
            # besttree_exists() check below - confirm.
            if self.cfg.mfresolv_method == "ultrafast":
                self.raxml_wrapper.copy_result_tree(
                    self.mfresolv_job_name,
                    self.raxml_wrapper.besttree_fname(self.mfresolv_job_name))

        if self.raxml_wrapper.besttree_exists(self.mfresolv_job_name):
            if not self.cfg.reopt_model:
                # use tree and model parameters of the constrained search as-is
                self.raxml_wrapper.copy_best_tree(self.mfresolv_job_name,
                                                  self.reftree_bfu_fname)
                self.raxml_wrapper.copy_optmod_params(self.mfresolv_job_name,
                                                      self.optmod_fname)
                self.invocation_raxml_optmod = ""
                # job whose output is used for the logLH lookup below
                job_name = self.mfresolv_job_name
            else:
                bfu_fname = self.raxml_wrapper.besttree_fname(
                    self.mfresolv_job_name)
                job_name = self.optmod_job_name

                # RAxML call to optimize model parameters and write them down to the binary model file
                self.cfg.log.debug("\nOptimizing model parameters: \n")
                raxml_params = [
                    "-f", "e", "-s", self.reduced_refalign_fname, "-t",
                    bfu_fname, "--no-seq-check"
                ]
                # NOTE(review): presumably -H switches off RAxML's pattern
                # compression, mirroring cfg.compress_patterns for CAT models
                if self.cfg.raxml_model.startswith(
                        "GTRCAT") and not self.cfg.compress_patterns:
                    raxml_params += ["-H"]
                if self.cfg.restart and self.raxml_wrapper.result_exists(
                        self.optmod_job_name):
                    self.invocation_raxml_optmod = self.raxml_wrapper.get_invocation_str(
                        self.optmod_job_name)
                    self.cfg.log.debug(
                        "\nUsing existing optimized tree and parameters found in: %s\n",
                        self.raxml_wrapper.result_fname(self.optmod_job_name))
                else:
                    self.invocation_raxml_optmod = self.raxml_wrapper.run(
                        self.optmod_job_name, raxml_params)
                if self.raxml_wrapper.result_exists(self.optmod_job_name):
                    self.raxml_wrapper.copy_result_tree(
                        self.optmod_job_name, self.reftree_bfu_fname)
                    self.raxml_wrapper.copy_optmod_params(
                        self.optmod_job_name, self.optmod_fname)
                else:
                    errmsg = "RAxML run failed (model optimization), please examine the log for details: %s" \
                            % self.raxml_wrapper.make_raxml_fname("output", self.optmod_job_name)
                    self.cfg.exit_fatal_error(errmsg)

            # the likelihood is read for the rate-heterogeneity model in use
            if self.cfg.raxml_model.startswith("GTRCAT"):
                mod_name = "CAT"
            else:
                mod_name = "GAMMA"
            self.reftree_loglh = self.raxml_wrapper.get_tree_lh(
                job_name, mod_name)
            self.cfg.log.debug("\n%s-based logLH of the reference tree: %f\n" %
                               (mod_name, self.reftree_loglh))

        else:
            errmsg = "RAxML run failed (mutlifurcation resolution), please examine the log for details: %s" \
                    % self.raxml_wrapper.make_raxml_fname("output", self.mfresolv_job_name)
            self.cfg.exit_fatal_error(errmsg)

    def load_reduced_refalign(self):
        formats = ["fasta", "phylip_relaxed"]
        for fmt in formats:
            try:
                self.reduced_refalign_seqs = SeqGroup(
                    sequences=self.reduced_refalign_fname, format=fmt)
                break
            except:
                pass
        if self.reduced_refalign_seqs == None:
            errmsg = "FATAL ERROR: Invalid input file format in %s! (load_reduced_refalign)" % self.reduced_refalign_fname
            self.cfg.exit_fatal_error(errmsg)

    # dummy EPA run to label the branches of the reference tree, which we need to build a mapping to tax ranks
    def epa_branch_labeling(self):
        # create alignment with dummy query seq
        self.refalign_width = len(self.reduced_refalign_seqs.get_seqbyid(0))
        self.reduced_refalign_seqs.write(format="fasta",
                                         outfile=self.lblalign_fname)

        with open(self.lblalign_fname, "a") as fout:
            fout.write(">" + "DUMMY131313" + "\n")
            fout.write("A" * self.refalign_width + "\n")

        # TODO always load model regardless of the config file settings?
        epa_result = self.raxml_wrapper.run_epa(self.epalbl_job_name,
                                                self.lblalign_fname,
                                                self.reftree_bfu_fname,
                                                self.optmod_fname,
                                                mode="epa_mp")
        self.reftree_lbl_str = epa_result.get_std_newick_tree()
        self.raxml_version = epa_result.get_raxml_version()
        self.invocation_raxml_epalbl = epa_result.get_raxml_invocation()

        if not self.raxml_wrapper.epa_result_exists(self.epalbl_job_name):
            errmsg = "RAxML EPA run failed, please examine the log for details: %s" \
                    % self.raxml_wrapper.make_raxml_fname("output", self.epalbl_job_name)
            self.cfg.exit_fatal_error(errmsg)

    def epa_post_process(self):
        """Build the taxonomy-annotated reference tree from the EPA-labelled tree.

        Sets ``reftree_tax`` and ``bid_ranks_map`` via the TaxTreeHelper. In
        debug mode the annotated tree, the raw labelled Newick string and the
        branch map (one tab-separated line per branch id) go to temp files.
        """
        helper = self.taxtree_helper
        helper.set_bf_unrooted_tree(Tree(self.reftree_lbl_str))
        self.reftree_tax = helper.get_tax_tree()
        self.bid_ranks_map = helper.get_bid_taxonomy_map()

        if not self.cfg.debug:
            return

        self.reftree_tax.write(outfile=self.reftree_tax_fname, format=3)
        with open(self.reftree_lbl_fname, "w") as lbl_file:
            lbl_file.write(self.reftree_lbl_str)
        with open(self.brmap_fname, "w") as map_file:
            for bid, rec in self.bid_ranks_map.iteritems():
                map_file.write("%s\t%s\t%d\t%f\n" % (bid, rec[0], rec[1], rec[2]))

    def calc_node_heights(self):
        """Calculate node heights on the reference tree.

        The heights are used to define the branch-length cutoff during the
        classification step. Algorithm:

          * tip node, or node resolved to species level: height = 1
          * inner node resolved to genus or above:       height = min(child heights) + 1

        Heights are keyed by the EPA branch label ``node.B`` and stored in
        ``self.node_height_map``. The root is skipped (it has no branch).
        Exactly one non-root node may lack an EPA label (see below); more than
        one triggers ``cfg.exit_fatal_error``.
        """
        nh_map = {}
        dummy_added = False
        # postorder: every child is handled before its parent, so child heights exist
        for node in self.reftree_tax.traverse("postorder"):
            if node.is_root():
                continue
            if hasattr(node, "B"):
                # NOTE(review): the last element of the branch's rank record is
                # compared against EMPTY_RANK as the "species" rank - confirm the
                # record layout produced by get_bid_taxonomy_map().
                species_rank = self.bid_ranks_map[node.B][-1]
            elif not dummy_added:
                # In a rooted tree, there is always one more node/branch than in the
                # unrooted one, so one branch is never EPA-labelled after rooting.
                # Give it a temporary label so its parent can still be processed.
                node.B = "DDD"
                dummy_added = True
                species_rank = Taxonomy.EMPTY_RANK
            else:
                errmsg = "FATAL ERROR: More than one tree branch without EPA label (calc_node_heights)"
                self.cfg.exit_fatal_error(errmsg)

            bid = node.B
            if node.is_leaf() or species_rank != Taxonomy.EMPTY_RANK:
                nh_map[bid] = 1
            else:
                # min over all children: identical to min(left, right) for binary
                # nodes, and does not fail on unary or multifurcating nodes
                nh_map[bid] = min(nh_map[child.B] for child in node.children) + 1

        # remove the height of the dummy branch: no placement can land on it
        if dummy_added:
            del nh_map["DDD"]

        self.node_height_map = nh_map

    def __get_all_rank_names(self, root):
        """Return the set of all rank names attached to *root* or any node below it.

        Every node visited by ``root.traverse("postorder")`` must expose an
        iterable ``ranks`` attribute.
        """
        collected = set()
        for current in root.traverse("postorder"):
            collected.update(current.ranks)
        return collected

    def mono_index(self):
        """Return the rank names shared by the two subtrees below the root.

        Single-child nodes directly under the root of ``self.reftree_tax`` are
        descended through until a node with more than one child is reached.
        The rank names found on both sides of that split are returned as a set.
        If that node does not have exactly two children (including a chain
        that ends in a leaf), an error is printed and an empty set is returned.
        """
        children = self.reftree_tax.children
        # skip any chain of unary nodes at the top of the tree
        while len(children) == 1:
            children = children[0].children

        if len(children) != 2:
            print("Error: input tree not bifurcating")
            return set()

        left, right = children
        lset = self.__get_all_rank_names(left)
        rset = self.__get_all_rank_names(right)
        return lset & rset

    def build_hmm_profile(self, json_builder):
        """Build an HMMER profile from the reduced reference alignment.

        The reduced alignment is re-exported as FASTA to a temporary file,
        ``hmmer.build_hmm_profile`` is run on it, and the resulting profile is
        attached to *json_builder* via ``set_hmm_profile``.

        :param json_builder: reference JSON builder receiving the profile.
        """
        # progress goes through the configured logger like every other step
        # (the former bare ``print`` statement was also Python-2-only syntax)
        self.cfg.log.info("Building the HMMER profile...\n")

        # RAxML outputs the reduced alignment in relaxed PHYLIP format, which
        # HMMER does not support, so write it out explicitly as FASTA first.
        refalign_fasta = self.cfg.tmp_fname("%NAME%_ref_reduced.fa")
        self.reduced_refalign_seqs.write(format="fasta", outfile=refalign_fasta)

        hmm = hmmer(self.cfg, refalign_fasta)
        fprofile = hmm.build_hmm_profile()

        json_builder.set_hmm_profile(fprofile)

    def write_json(self):
        """Collect all reference artefacts and dump them as the reference JSON file.

        Stores the branch->taxonomy map, labelled tree, outgroup, rate
        heterogeneity model, multifurcating taxonomy tree, pattern-compression
        flag, taxonomic code, rank/sequence-id correction maps (inverted),
        run metadata, reference sequences, optionally an HMMER profile, the
        original taxonomy, the speciation rate, node heights and the binary
        model, then writes everything to ``cfg.refjson_fname``.
        """
        builder = RefJsonBuilder()

        # tree, taxonomy and model settings
        builder.set_branch_tax_map(self.bid_ranks_map)
        builder.set_tree(self.reftree_lbl_str)
        builder.set_outgroup(self.reftree_outgroup)
        builder.set_ratehet_model(self.cfg.raxml_model)
        builder.set_tax_tree(self.reftree_multif)
        builder.set_pattern_compression(self.cfg.compress_patterns)
        builder.set_taxcode(self.cfg.taxcode_name)

        # name-correction maps are stored inverted (corrected name -> original)
        validator = self.input_validator
        builder.set_merged_ranks_map(validator.merged_ranks)
        builder.set_corr_ranks_map(
            dict((new, old) for old, new in validator.corr_ranks.items()))
        builder.set_corr_seqid_map(
            dict((new, old) for old, new in validator.corr_seqid.items()))

        builder.set_metadata({
            "ref_tree_size": self.reftree_size,
            "ref_alignment_width": self.refalign_width,
            "raxml_version": self.raxml_version,
            "timestamp": str(datetime.datetime.now()),
            "invocation_epac": self.invocation_epac,
            "invocation_raxml_multif": self.invocation_raxml_multif,
            "invocation_raxml_optmod": self.invocation_raxml_optmod,
            "invocation_raxml_epalbl": self.invocation_raxml_epalbl,
            "reftree_loglh": self.reftree_loglh
        })

        builder.set_sequences(self.reduced_refalign_seqs.get_entries())

        if not self.cfg.no_hmmer:
            self.build_hmm_profile(builder)

        taxonomy = self.taxonomy_map
        builder.set_origin_taxonomy(taxonomy)

        self.cfg.log.debug("Calculating the speciation rate...\n")
        params = tree_param(tree=self.reftree_lbl_str, origin_taxonomy=taxonomy)
        builder.set_rate(params.get_speciation_rate_fast())
        builder.set_nodes_height(self.node_height_map)

        builder.set_binary_model(self.optmod_fname)

        self.cfg.log.debug("Writing down the reference file...\n")
        builder.dump(self.cfg.refjson_fname)

    # top-level function to build a reference tree
    def build_ref_tree(self):
        """Run the full reference-building pipeline, step by step.

        Order: load taxonomy -> load alignment -> validate -> build the
        multifurcating taxonomy tree -> export reference alignment and taxonomy
        -> save rooting -> resolve multifurcation with RAxML -> load the reduced
        alignment -> EPA branch labelling -> post-process -> node heights ->
        debug rank checks -> write the reference JSON file.
        """
        self.cfg.log.info("=> Loading taxonomy from file: %s ...\n",
                          self.cfg.taxonomy_fname)
        self.taxonomy = Taxonomy(prefix=EpacConfig.REF_SEQ_PREFIX,
                                 tax_fname=self.cfg.taxonomy_fname)

        self.cfg.log.info("==> Loading reference alignment from file: %s ...\n",
                          self.cfg.align_fname)
        self.load_alignment()

        self.cfg.log.info("===> Validating taxonomy and alignment ...\n")
        self.validate_taxonomy()

        self.cfg.log.info(
            "====> Building a multifurcating tree from taxonomy with %d seqs ...\n",
            self.taxonomy.seq_count())
        self.build_multif_tree()

        self.cfg.log.info("=====> Building the reference alignment ...\n")
        self.export_ref_alignment()
        self.export_ref_taxonomy()

        self.cfg.log.info("======> Saving the outgroup for later re-rooting ...\n")
        self.save_rooting()

        self.cfg.log.info(
            "=======> Resolving multifurcation: choosing the best topology from %d independent RAxML runs ...\n",
            self.cfg.rep_num)
        self.resolve_multif()
        self.load_reduced_refalign()

        self.cfg.log.info("========> Calling RAxML-EPA to obtain branch labels ...\n")
        self.epa_branch_labeling()

        self.cfg.log.info(
            "=========> Post-processing the EPA tree (re-rooting, taxonomic labeling etc.) ...\n")
        self.epa_post_process()
        self.calc_node_heights()

        self.cfg.log.debug("\n==========> Checking branch labels ...")
        ranks_before = repr(self.taxonomy.get_common_ranks())
        self.cfg.log.debug("shared rank names before training: %s", ranks_before)
        ranks_after = repr(self.mono_index())
        self.cfg.log.debug("shared rank names after  training: %s\n", ranks_after)

        self.cfg.log.info("==========> Saving the reference JSON file: %s\n",
                          self.cfg.refjson_fname)
        self.write_json()
Example #21
0
class EpaClassifier:
    def __init__(self, config, args):
        """Prepare a classifier run.

        :param config: run configuration; ``tmp_fname``, ``refjson_fname``,
            ``log`` and ``exit_user_error`` are used here.
        :param args: parsed command-line arguments; ``jplace_fname``,
            ``ignore_refalign``, ``output_name`` and ``output_dir`` are read.

        Loads and validates the reference JSON, and rebuilds the branch->taxonomy
        map when the reference file does not contain one (pre-1.6 format).
        """
        self.cfg = config
        self.jplace_fname = args.jplace_fname
        self.ignore_refalign = args.ignore_refalign

        # temporary working files
        tmp_name = config.tmp_fname
        self.tmp_refaln = tmp_name("%NAME%.refaln")
        # final alignment handed to EPA (aligners may later replace this path)
        self.epa_alignment = tmp_name("%NAME%.afa")
        self.hmmprofile = tmp_name("%NAME%.hmmprofile")
        self.tmpquery = tmp_name("%NAME%.tmpquery")
        self.noalign = tmp_name("%NAME%.noalign")
        self.seqs = None

        # user-visible output files
        self.out_assign_fname = os.path.join(
            args.output_dir, args.output_name + ".assignment.txt")
        self.out_jplace_fname = os.path.join(
            args.output_dir, args.output_name + ".jplace")

        try:
            self.refjson = RefJsonParser(config.refjson_fname)
        except ValueError:
            self.cfg.exit_user_error("Invalid json file format: %s" %
                                     config.refjson_fname)
        # validate input json format
        self.refjson.validate()

        refjson = self.refjson
        self.reftree = refjson.get_reftree()
        self.rate = refjson.get_rate()
        self.node_height = refjson.get_node_height()
        self.cfg.compress_patterns = refjson.get_pattern_compression()

        self.bid_taxonomy_map = refjson.get_branch_tax_map()
        if not self.bid_taxonomy_map:
            # old file format (before 1.6): rebuild the map from scratch
            helper = TaxTreeHelper(self.cfg, refjson.get_origin_taxonomy())
            helper.set_mf_rooted_tree(refjson.get_tax_tree())
            helper.set_bf_unrooted_tree(refjson.get_reftree())
            self.bid_taxonomy_map = helper.get_bid_taxonomy_map()

        taxa_count = len(self.reftree.get_leaves())
        self.cfg.log.info("Loaded reference tree with %d taxa\n" % taxa_count)

        self.classify_helper = TaxClassifyHelper(self.cfg,
                                                 self.bid_taxonomy_map,
                                                 self.rate, self.node_height)

    def require_muscle(self):
        basepath = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(basepath + "/epac/bin/muscle"):
            errmsg = "The pipeline uses MUSCLE to merge alignments, please download the programm from:\n" + \
                     "http://www.drive5.com/muscle/downloads.htm\n" + \
                     "and specify path to your installation in the config file (sativa.cfg)\n"
            self.cfg.exit_user_error(errmsg)

    def require_hmmer(self):
        basepath = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(basepath +
                              "/epac/bin/hmmbuild") or not os.path.exists(
                                  basepath + "/epac/bin/hmmalign"):
            errmsg = "The pipeline uses HAMMER to align the query seqeunces, please download the programm from:\n" + \
                     "http://hmmer.janelia.org/\n" + \
                     "and specify path to your installation in the config file (sativa.cfg)\n"
            self.cfg.exit_user_error(errmsg)

    def align_to_refenence(self, noalign, minp=0.9):
        """Align the query sequences to the reference alignment with HMMER.

        Uses the HMM profile stored in the reference JSON when present;
        otherwise a profile is built from the reference alignment. Sets
        ``self.epa_alignment`` to the path returned by the aligner.

        :param noalign: file that receives the queries the aligner discards.
        :param minp: threshold forwarded to the ``hmmer`` wrapper as ``minp``.
        """
        ref_fname = self.refjson.get_alignment(fout=self.tmp_refaln)
        profile = self.refjson.get_hmm_profile(self.hmmprofile)

        if not profile:
            # no profile in the json file: build it from scratch
            profile = hmmer(self.cfg, ref_fname).build_hmm_profile()

        aligner = hmmer(config=self.cfg,
                        refalign=ref_fname,
                        query=self.tmpquery,
                        refprofile=profile,
                        discard=noalign,
                        seqs=self.seqs,
                        minp=minp)
        self.epa_alignment = aligner.align()

    def merge_alignment(self, query_seqs):
        refaln = self.refjson.get_alignment_list()
        with open(self.epa_alignment, "w") as fout:
            for seq in refaln:
                fout.write(">" + seq[0] + "\n" + seq[1] + "\n")
            for name, seq, comment, sid in query_seqs.iter_entries():
                fout.write(">" + name + "\n" + seq + "\n")

    def write_combined_alignment(self):
        """Write the query file as the EPA alignment when it already contains the references.

        Each sequence name is prefixed with the reference prefix and mapped
        through ``refjson.get_corr_seqid``. If the result is a known reference
        sequence name, it is written under that name. Otherwise the sequence is
        treated as a query, written with the query prefix, and counted in
        ``self.query_count``.
        """
        self.query_count = 0
        # fetch the reference names once and use a set for O(1) membership tests,
        # instead of re-fetching and linearly scanning them for every sequence
        ref_names = set(self.refjson.get_sequences_names())
        with open(self.epa_alignment, "w") as fout:
            for name, seq, comment, sid in self.seqs.iter_entries():
                ref_name = self.refjson.get_corr_seqid(
                    EpacConfig.REF_SEQ_PREFIX + name)
                if ref_name in ref_names:
                    seq_name = ref_name
                else:
                    seq_name = EpacConfig.QUERY_SEQ_PREFIX + name
                    self.query_count += 1
                fout.write(">" + seq_name + "\n" + seq + "\n")

    def checkinput(self, query_fname, minp=0.9):
        """Load the query file and prepare the alignment that EPA will run on.

        Tries the formats fasta, phylip, iphylip, phylip_relaxed and
        iphylip_relaxed in that order. Then:

          * with ``ignore_refalign``: the file is assumed to already contain the
            reference sequences and is written out directly;
          * aligned queries of the reference width: merged with the reference
            alignment as-is;
          * aligned queries of another width: merged with MUSCLE;
          * unaligned (or single) queries: aligned with HMMER.

        :param query_fname: path of the query sequence file.
        :param minp: threshold forwarded to the HMMER alignment step.
        """
        # reset so a previous call cannot mask a parse failure of this file
        self.seqs = None
        formats = [
            "fasta", "phylip", "iphylip", "phylip_relaxed", "iphylip_relaxed"
        ]
        for fmt in formats:
            try:
                self.seqs = SeqGroup(sequences=query_fname, format=fmt)
                break
            except Exception:
                # parse failure for this format: try the next one
                self.cfg.log.debug("Guessing input format: not " + fmt)
        if self.seqs is None:
            self.cfg.exit_user_error(
                "Invalid input file format: %s\nThe supported input formats are fasta and phylip"
                % query_fname)
            return

        if self.ignore_refalign:
            self.cfg.log.info(
                "Assuming query file contains reference sequences, skipping the alignment step...\n"
            )
            self.write_combined_alignment()
            return

        if len(self.seqs) == 0:
            self.cfg.exit_user_error("No sequences found in query file: %s" %
                                     query_fname)
            return

        self.query_count = len(self.seqs)

        # add query seq name prefix to avoid confusion between reference and query sequences
        self.seqs.add_name_prefix(EpacConfig.QUERY_SEQ_PREFIX)

        self.seqs.write(format="fasta", outfile=self.tmpquery)
        self.cfg.log.info("Checking if query sequences are aligned ...")
        entries = self.seqs.get_entries()
        seql = len(entries[0][1])
        # "aligned" means every query has the same length as the first one
        aligned = all(len(entry[1]) == seql for entry in entries[1:])

        if aligned and len(self.seqs) > 1:
            self.cfg.log.info("Query sequences are aligned")
            refalnl = self.refjson.get_alignment_length()
            if refalnl == seql:
                self.cfg.log.info(
                    "Merging query alignment with reference alignment")
                self.merge_alignment(self.seqs)
            else:
                self.cfg.log.info(
                    "Merging query alignment with reference alignment using MUSCLE"
                )
                self.require_muscle()
                refaln = self.refjson.get_alignment(fout=self.tmp_refaln)
                m = muscle(self.cfg)
                self.epa_alignment = m.merge(refaln, self.tmpquery)
        else:
            self.cfg.log.info("Query sequences are not aligned")
            self.cfg.log.info(
                "Align query sequences to the reference alignment using HMMER")
            self.require_hmmer()
            self.align_to_refenence(self.noalign, minp=minp)

    def print_ranks(self, rks, confs, minlw=0.0):
        uncorr_ranks = self.refjson.get_uncorr_ranks(rks)
        ss = ""
        css = ""
        for i in range(len(uncorr_ranks)):
            conf = confs[i]
            if conf == confs[0] and confs[0] >= 0.99:
                conf = 1.0
            if conf >= minlw:
                ss = ss + uncorr_ranks[i] + ";"
                css = css + "{0:.3f}".format(conf) + ";"
            else:
                break
        if ss == "":
            return None
        else:
            return ss[:-1] + "\t" + css[:-1]

    def run_epa(self):
        """Place the query sequences on the reference tree with RAxML-EPA.

        Writes the reference tree and the binary model from the reference JSON
        to temporary files, resolves tree-size dependent settings, reduces the
        combined alignment and runs EPA. The resulting jplace file is moved to
        ``self.out_jplace_fname``.

        :returns: the result object of ``RaxmlWrapper.run_epa``.
        """
        self.cfg.log.info(
            "Running RAxML-EPA to place %d query sequences...\n" %
            self.query_count)
        # Use the instance's configuration: the bare name ``config`` is not a
        # local here and only resolved if a module-level global happened to exist.
        raxml = RaxmlWrapper(self.cfg)
        reftree_fname = self.cfg.tmp_fname("ref_%NAME%.tre")
        self.refjson.get_raxml_readable_tree(reftree_fname)
        optmod_fname = self.cfg.tmp_fname("%NAME%.opt")
        self.refjson.get_binary_model(optmod_fname)
        job_name = self.cfg.subst_name("epa_%NAME%")

        reftree_str = self.refjson.get_raxml_readable_tree()
        reftree = Tree(reftree_str)

        self.reftree_size = len(reftree.get_leaves())

        # IMPORTANT: set EPA heuristic rate based on tree size!
        self.cfg.resolve_auto_settings(self.reftree_size)
        # If we're loading the pre-optimized model, we MUST set the same rate het. mode as in the ref file
        if self.cfg.epa_load_optmod:
            self.cfg.raxml_model = self.refjson.get_ratehet_model()

        reduced_align_fname = raxml.reduce_alignment(self.epa_alignment)

        jp = raxml.run_epa(job_name, reduced_align_fname, reftree_fname,
                           optmod_fname)

        raxml.copy_epa_jplace(job_name, self.out_jplace_fname, move=True)

        return jp

    def run_ptp(self, jp, fout=None):
        """Run EPA-PTP species delimitation on the EPA placements.

        Species clusters (query prefix stripped, comma-separated) are written
        to the debug log and, when an output name is available, to
        ``<fout>.species`` with one cluster per line.

        :param jp: EPA placement result passed to ``epa_2_ptp``.
        :param fout: base name of the species file; defaults to
            ``self.out_assign_fname`` (previously this name was referenced
            without ever being defined, which raised NameError).
        """
        if fout is None:
            fout = self.out_assign_fname

        full_aln = SeqGroup(self.epa_alignment)
        species_list = epa_2_ptp(epa_jp=jp,
                                 ref_jp=self.refjson,
                                 full_alignment=full_aln,
                                 min_lw=0.5,
                                 debug=self.cfg.debug)

        self.cfg.log.debug("Species clusters:")

        fo2 = open(fout + ".species", "w") if fout else None
        try:
            for sp_cluster in species_list:
                s = ",".join(EpacConfig.strip_query_prefix(taxon)
                             for taxon in sp_cluster)
                if fo2:
                    fo2.write(s + "\n")
                self.cfg.log.debug(s)
        finally:
            # close the species file even if a cluster fails to format
            if fo2:
                fo2.close()

    def print_result_line(self, fo, line):
        if self.cfg.verbose:
            print(line)
        if fo:
            fo.write(line + "\n")

    def get_noalign_list(self):
        noalign_list = []
        if os.path.exists(self.noalign):
            with open(self.noalign) as fnoa:
                lines = fnoa.readlines()
                for line in lines:
                    taxon_name = line.strip()[1:]
                    origin_taxon_name = EpacConfig.strip_query_prefix(
                        taxon_name)
                    noalign_list.append(origin_taxon_name)
        return noalign_list

    def classify(self, query_fname, minp=0.9, ptp=False):
        """Assign taxonomic labels to the query sequences.

        Placements come from ``self.jplace_fname`` when given; otherwise the
        query file is checked/aligned and EPA is run. Each placed query gets a
        line ``name<TAB>ranks<TAB>confidences<TAB>[novelty flag]``. Queries
        without an assignment, or that could not be aligned, get
        ``name<TAB><TAB><TAB>?``.

        :param query_fname: query sequence file (ignored when a jplace file is given).
        :param minp: threshold forwarded to the alignment step.
        :param ptp: also run EPA-PTP species delimitation afterwards.
        """
        if self.jplace_fname:
            jp = EpaJsonParser(self.jplace_fname)
        else:
            self.checkinput(query_fname, minp)
            jp = self.run_epa()

        self.cfg.log.info(
            "Assigning taxonomic labels based on EPA placements...\n")

        placements = jp.get_placement()

        fo = open(self.out_assign_fname, "w") if self.out_assign_fname else None
        try:
            noassign_list = []
            for place in placements:
                taxon_name = place["n"][0]
                origin_taxon_name = EpacConfig.strip_query_prefix(taxon_name)
                edges = place["p"]

                ranks, lws = self.classify_helper.classify_seq(edges)
                rankout = self.print_ranks(ranks, lws, self.cfg.min_lhw)

                if rankout is None:
                    noassign_list.append(origin_taxon_name)
                else:
                    output = "%s\t%s\t" % (origin_taxon_name, rankout)
                    if self.cfg.check_novelty:
                        isnovo = self.novelty_check(place_edge=str(edges[0][0]),
                                                    ranks=ranks,
                                                    lws=lws)
                        output += "*" if isnovo else "o"
                    self.print_result_line(fo, output)

            noassign_list += self.get_noalign_list()

            # one line per unassigned query - use the loop variable, not the
            # name left over from the placement loop above
            for taxon_name in noassign_list:
                output = "%s\t\t\t?" % taxon_name
                self.print_result_line(fo, output)
        finally:
            # close the assignment file even if classification fails midway
            if fo:
                fo.close()

        #############################################
        #
        # EPA-PTP species delimitation
        #
        #############################################
        if ptp:
            self.run_ptp(jp)

    def novelty_check(self, place_edge, ranks, lws):
        """If the taxonomic assignment is not assigned to the genus level, 
        we need to check if it is due to the incomplete reference taxonomy or 
        it is likely to be something new:
        
        1. If the final ranks are assinged because of lw cut, that means with samller lw
        the ranks can be further assinged to lowers. This indicate the undetermined ranks 
        in the assignment is not due to the incomplete reference taxonomy, so the query 
        sequence is likely to be something new.
        
        2. Otherwise We check all leaf nodes' immediate lower rank below this ml placement point, 
        if they are not empty, output all ranks and indicate this could be novelty.
        """

        lowrank = 0
        for i in max(range(len(ranks)), 6):
            """above genus level"""
            rk = ranks[i]
            lw = lws[i]
            if rk == "-":
                break
            else:
                lowrank = lowrank + 1
                if lw >= 0 and lw < self.cfg.min_lhw:
                    return True

        if lowrank >= 5 and lowrank < len(ranks) and not ranks[lowrank] == "-":
            return False
        else:
            placenode = self.reftree.search_nodes(B=place_edge)[0]
            if placenode.is_leaf():
                return False
            else:
                leafnodes = placenode.get_leaves()
                flag = True
                for leaf in leafnodes:
                    br_num = leaf.B
                    branks = self.bid_taxonomy_map[br_num]
                    if lowrank >= len(branks) or branks[lowrank] == "-":
                        flag = False
                        break

                return flag