Пример #1
0
def test_ancestral():
    """Check ML ancestral reconstruction and likelihood normalization.

    Part 1 reconstructs the root of the 20-sequence H3N2 NA example with
    both marginal and joint ML and compares it to a fixed expected sequence.
    Part 2 builds a three-taxon toy alignment whose 64 columns enumerate
    every possible A/B/C site pattern (4**3), so the per-pattern
    likelihoods must sum to 1.

    Requires treetime, Biopython and the example files in ../data relative
    to this file.
    """
    import os
    try:
        from StringIO import StringIO  # Python 2
    except ImportError:
        from io import StringIO  # Python 3
    import numpy as np
    from Bio import AlignIO, Phylo
    from treetime import TreeAnc, GTR
    root_dir = os.path.dirname(os.path.realpath(__file__))
    fasta = str(os.path.join(root_dir, '../data/H3N2_NA_allyears_NA.20.fasta'))
    nwk = str(os.path.join(root_dir, '../data/H3N2_NA_allyears_NA.20.nwk'))
    expected_root_seq = 'ATGAATCCAAATCAAAAGATAATAACGATTGGCTCTGTTTCTCTCACCATTTCCACAATATGCTTCTTCATGCAAATTGCCATCTTGATAACTACTGTAACATTGCATTTCAAGCAATATGAATTCAACTCCCCCCCAAACAACCAAGTGATGCTGTGTGAACCAACAATAATAGAAAGAAACATAACAGAGATAGTGTATCTGACCAACACCACCATAGAGAAGGAAATATGCCCCAAACCAGCAGAATACAGAAATTGGTCAAAACCGCAATGTGGCATTACAGGATTTGCACCTTTCTCTAAGGACAATTCGATTAGGCTTTCCGCTGGTGGGGACATCTGGGTGACAAGAGAACCTTATGTGTCATGCGATCCTGACAAGTGTTATCAATTTGCCCTTGGACAGGGAACAACACTAAACAACGTGCATTCAAATAACACAGTACGTGATAGGACCCCTTATCGGACTCTATTGATGAATGAGTTGGGTGTTCCTTTTCATCTGGGGACCAAGCAAGTGTGCATAGCATGGTCCAGCTCAAGTTGTCACGATGGAAAAGCATGGCTGCATGTTTGTATAACGGGGGATGATAAAAATGCAACTGCTAGCTTCATTTACAATGGGAGGCTTGTAGATAGTGTTGTTTCATGGTCCAAAGAAATTCTCAGGACCCAGGAGTCAGAATGCGTTTGTATCAATGGAACTTGTACAGTAGTAATGACTGATGGAAGTGCTTCAGGAAAAGCTGATACTAAAATACTATTCATTGAGGAGGGGAAAATCGTTCATACTAGCACATTGTCAGGAAGTGCTCAGCATGTCGAAGAGTGCTCTTGCTATCCTCGATATCCTGGTGTCAGATGTGTCTGCAGAGACAACTGGAAAGGCTCCAATCGGCCCATCGTAGATATAAACATAAAGGATCATAGCATTGTTTCCAGTTATGTGTGTTCAGGACTTGTTGGAGACACACCCAGAAAAAACGACAGCTCCAGCAGTAGCCATTGTTTGGATCCTAACAATGAAGAAGGTGGTCATGGAGTGAAAGGCTGGGCCTTTGATGATGGAAATGACGTGTGGATGGGAAGAACAATCAACGAGACGTCACGCTTAGGGTATGAAACCTTCAAAGTCATTGAAGGCTGGTCCAACCCTAAGTCCAAATTGCAGATAAATAGGCAAGTCATAGTTGACAGAGGTGATAGGTCCGGTTATTCTGGTATTTTCTCTGTTGAAGGCAAAAGCTGCATCAATCGGTGCTTTTATGTGGAGTTGATTAGGGGAAGAAAAGAGGAAACTGAAGTCTTGTGGACCTCAAACAGTATTGTTGTGTTTTGTGGCACCTCAGGTACATATGGAACAGGCTCATGGCCTGATGGGGCGGACCTCAATCTCATGCCTATA'

    for marginal in [True, False]:
        print('loading flu example')
        t = TreeAnc(gtr='Jukes-Cantor', tree=nwk, aln=fasta)
        print('ancestral reconstruction ' + ("marginal" if marginal else "joint"))
        t.reconstruct_anc(method='ml', marginal=marginal)
        # marginal and joint reconstruction must agree on the root sequence
        assert "".join(t.tree.root.sequence) == expected_root_seq

    print('testing LH normalization')
    tiny_tree = Phylo.read(StringIO("((A:0.60100000009,B:0.3010000009):0.1,C:0.2):0.001;"), 'newick')
    # the 64 columns enumerate every combination of states at A, B and C
    tiny_aln = AlignIO.read(StringIO(">A\nAAAAAAAAAAAAAAAACCCCCCCCCCCCCCCCGGGGGGGGGGGGGGGGTTTTTTTTTTTTTTTT\n"
                                     ">B\nAAAACCCCGGGGTTTTAAAACCCCGGGGTTTTAAAACCCCGGGGTTTTAAAACCCCGGGGTTTT\n"
                                     ">C\nACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGT\n"), 'fasta')

    mygtr = GTR.custom(alphabet = np.array(['A', 'C', 'G', 'T']), pi = np.array([0.9, 0.06, 0.02, 0.02]), W=np.ones((4,4)))
    t = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=tiny_aln)
    t.reconstruct_anc('ml', marginal=True, debug=True)
    # assumes marginal_profile rows are per-site likelihoods up to the
    # factor exp(marginal_subtree_LH_prefactor) -- TODO confirm against treetime
    lhsum =  (t.tree.root.marginal_profile.sum(axis=1) * np.exp(t.tree.root.marginal_subtree_LH_prefactor)).sum()
    print (lhsum)
    assert(np.abs(lhsum-1.0)<1e-6)

    t.optimize_branch_len()
Пример #2
0
def test_ancestral():
    """Check ML ancestral reconstruction and likelihood normalization.

    Part 1 reconstructs the root of the 20-sequence H3N2 NA example with
    both marginal and joint ML and compares it to a fixed expected sequence.
    Part 2 builds a three-taxon toy alignment whose 64 columns enumerate
    every possible A/B/C site pattern (4**3), so the per-pattern
    likelihoods returned by sequence_LH must sum to 1.

    Requires treetime, Biopython and the treetime_examples data files
    relative to this file.
    """
    import os
    # StringIO was used below without being imported (NameError)
    try:
        from StringIO import StringIO  # Python 2
    except ImportError:
        from io import StringIO  # Python 3
    import numpy as np
    from Bio import AlignIO, Phylo
    from treetime import TreeAnc, GTR
    root_dir = os.path.dirname(os.path.realpath(__file__))
    fasta = str(os.path.join(root_dir, 'treetime_examples/data/h3n2_na/h3n2_na_20.fasta'))
    nwk = str(os.path.join(root_dir, 'treetime_examples/data/h3n2_na/h3n2_na_20.nwk'))
    expected_root_seq = 'ATGAATCCAAATCAAAAGATAATAACGATTGGCTCTGTTTCTCTCACCATTTCCACAATATGCTTCTTCATGCAAATTGCCATCTTGATAACTACTGTAACATTGCATTTCAAGCAATATGAATTCAACTCCCCCCCAAACAACCAAGTGATGCTGTGTGAACCAACAATAATAGAAAGAAACATAACAGAGATAGTGTATCTGACCAACACCACCATAGAGAAGGAAATATGCCCCAAACCAGCAGAATACAGAAATTGGTCAAAACCGCAATGTGGCATTACAGGATTTGCACCTTTCTCTAAGGACAATTCGATTAGGCTTTCCGCTGGTGGGGACATCTGGGTGACAAGAGAACCTTATGTGTCATGCGATCCTGACAAGTGTTATCAATTTGCCCTTGGACAGGGAACAACACTAAACAACGTGCATTCAAATAACACAGTACGTGATAGGACCCCTTATCGGACTCTATTGATGAATGAGTTGGGTGTTCCTTTTCATCTGGGGACCAAGCAAGTGTGCATAGCATGGTCCAGCTCAAGTTGTCACGATGGAAAAGCATGGCTGCATGTTTGTATAACGGGGGATGATAAAAATGCAACTGCTAGCTTCATTTACAATGGGAGGCTTGTAGATAGTGTTGTTTCATGGTCCAAAGAAATTCTCAGGACCCAGGAGTCAGAATGCGTTTGTATCAATGGAACTTGTACAGTAGTAATGACTGATGGAAGTGCTTCAGGAAAAGCTGATACTAAAATACTATTCATTGAGGAGGGGAAAATCGTTCATACTAGCACATTGTCAGGAAGTGCTCAGCATGTCGAAGAGTGCTCTTGCTATCCTCGATATCCTGGTGTCAGATGTGTCTGCAGAGACAACTGGAAAGGCTCCAATCGGCCCATCGTAGATATAAACATAAAGGATCATAGCATTGTTTCCAGTTATGTGTGTTCAGGACTTGTTGGAGACACACCCAGAAAAAACGACAGCTCCAGCAGTAGCCATTGTTTGGATCCTAACAATGAAGAAGGTGGTCATGGAGTGAAAGGCTGGGCCTTTGATGATGGAAATGACGTGTGGATGGGAAGAACAATCAACGAGACGTCACGCTTAGGGTATGAAACCTTCAAAGTCATTGAAGGCTGGTCCAACCCTAAGTCCAAATTGCAGATAAATAGGCAAGTCATAGTTGACAGAGGTGATAGGTCCGGTTATTCTGGTATTTTCTCTGTTGAAGGCAAAAGCTGCATCAATCGGTGCTTTTATGTGGAGTTGATTAGGGGAAGAAAAGAGGAAACTGAAGTCTTGTGGACCTCAAACAGTATTGTTGTGTTTTGTGGCACCTCAGGTACATATGGAACAGGCTCATGGCCTGATGGGGCGGACCTCAATCTCATGCCTATA'

    for marginal in [True, False]:
        print('loading flu example')
        t = TreeAnc(gtr='Jukes-Cantor', tree=nwk, aln=fasta)
        print('ancestral reconstruction ' + ("marginal" if marginal else "joint"))
        t.reconstruct_anc(method='ml', marginal=marginal)
        # marginal and joint reconstruction must agree on the root sequence
        assert "".join(t.tree.root.sequence) == expected_root_seq

    print('testing LH normalization')
    tiny_tree = Phylo.read(StringIO("((A:0.60100000009,B:0.3010000009):0.1,C:0.2):0.001;"), 'newick')
    # the 64 columns enumerate every combination of states at A, B and C
    tiny_aln = AlignIO.read(StringIO(">A\nAAAAAAAAAAAAAAAACCCCCCCCCCCCCCCCGGGGGGGGGGGGGGGGTTTTTTTTTTTTTTTT\n"
                                     ">B\nAAAACCCCGGGGTTTTAAAACCCCGGGGTTTTAAAACCCCGGGGTTTTAAAACCCCGGGGTTTT\n"
                                     ">C\nACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGT\n"), 'fasta')

    mygtr = GTR.custom(alphabet = np.array(['A', 'C', 'G', 'T']), pi = np.array([0.9, 0.06, 0.02, 0.02]), W=np.ones((4,4)))
    t = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=tiny_aln)
    t.reconstruct_anc('ml', marginal=True, debug=True)
    # sum of per-site likelihoods over all 4**3 site patterns
    lhsum =  np.exp(t.sequence_LH(pos=np.arange(4**3))).sum()
    print (lhsum)
    assert(np.abs(lhsum-1.0)<1e-6)

    t.optimize_branch_len()
def infer_gene_gain_loss(path, rates=(1.0, 1.0)):
    '''
    Reconstruct ancestral gene presence/absence patterns on the strain tree
    with a two-state ('0'/'1') GTR model.

    Parameters:
    - path: run directory containing geneCluster/genePresence.aln (pseudo
      alignment of presence/absence patterns such as 001001010110) and
      geneCluster/strain_tree.nwk (strain tree based on core gene SNPs)
    - rates: two non-negative weights for states '0' and '1'; their sum is
      used as the model rate mu and, normalized, as the equilibrium
      frequencies pi

    Returns:
    - the treetime TreeAnc instance; every clade gets a `genepresence`
      attribute equal to its (reconstructed or observed) sequence

    Raises:
    - ValueError: if rates do not sum to a positive number
    '''
    mu = np.sum(rates)
    if not mu > 0:  # also rejects nan
        raise ValueError("rates must sum to a positive number, got %r" % (rates,))
    gene_pi = np.array(rates) / mu
    gain_loss_model = GTR.custom(pi=gene_pi,
                                 mu=mu,
                                 W=np.ones((2, 2)),
                                 alphabet=np.array(['0', '1']))
    # '-' marks an unknown state: compatible with both '0' and '1'
    gain_loss_model.profile_map['-'] = np.ones(2)

    sep = '/'
    cluster_dir = sep.join([path.rstrip(sep), 'geneCluster'])
    fasta = sep.join([cluster_dir, 'genePresence.aln'])
    nwk = sep.join([cluster_dir, 'strain_tree.nwk'])

    # instantiate treetime with custom GTR
    t = TreeAnc(nwk, gtr=gain_loss_model, verbose=2)
    # Bio.Phylo interprets purely numeric leaf names as confidence values.
    # NOTE(review): if confidence is parsed as a float, '1' becomes '1.0';
    # confirm this matches the alignment record IDs.
    for leaf in t.tree.get_terminals():
        if leaf.name is None:
            leaf.name = str(leaf.confidence)
    t.aln = fasta
    t.tree.root.branch_length = 0.0001
    t.reconstruct_anc(method='ml')

    for n in t.tree.find_clades():
        n.genepresence = n.sequence

    return t
Пример #4
0
    def real_lh():
        """
        Run a joint ML ancestral reconstruction on the three sequences
        A_char, B_char and D_char (taken from the enclosing scope) over
        tiny_tree with the model mygtr, and return the `sequence_LH`
        attribute left on the tree.

        NOTE(review): the value is named logLH in the original, suggesting a
        log-likelihood; confirm against the treetime version in use.
        """
        fasta_text = (">A\n" + A_char + "\n"
                      + ">B\n" + B_char + "\n"
                      + ">D\n" + D_char + "\n")
        alignment = AlignIO.read(StringIO(fasta_text), 'fasta')

        anc = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=alignment, verbose=4)
        anc.reconstruct_anc(method='ml', marginal=False, debug=True)
        return anc.tree.sequence_LH
Пример #5
0
    def real_lh():
        """
        Run a joint ML ancestral reconstruction on the three sequences
        A_char, B_char and D_char (taken from the enclosing scope) over
        tiny_tree with the model mygtr, and return the `sequence_LH`
        attribute left on the tree.

        NOTE(review): the value is named logLH in the original, suggesting a
        log-likelihood; confirm against the treetime version in use.
        """
        fasta_text = (">A\n" + A_char + "\n"
                      + ">B\n" + B_char + "\n"
                      + ">D\n" + D_char + "\n")
        alignment = AlignIO.read(StringIO(fasta_text), 'fasta')

        anc = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=alignment, verbose=4)
        anc.reconstruct_anc(method='ml', marginal=False, debug=True)
        return anc.tree.sequence_LH
class mpm_tree(object):
    '''
    class that aligns a set of sequences and infers a tree

    The methods are meant to be chained: align() or codon_align() sets
    self.aln, build() sets self.tree (and self.tt), ancestral()/refine()
    annotate the nodes, and export() writes the results.
    '''
    def __init__(self, cluster_seq_filepath, **kwarks):
        '''
        Load the sequences of one gene cluster.

        Parameters:
        - cluster_seq_filepath: FASTA file; the cluster ID is the file name
          up to '.fna', and the third-to-last path component is used as the
          species folder unless 'speciesID' is given
        - speciesID (keyword, optional): species folder name override
        - run_dir (keyword, optional): scratch directory; otherwise a unique
          'tmp/<species>_tmp_<HHMMSS>_<random>' name is generated
        '''
        self.clusterID = cluster_seq_filepath.split('/')[-1].split('.fna')[0]
        if 'speciesID' in kwarks:
            folderID = kwarks['speciesID']
        else:
            folderID = cluster_seq_filepath.split('/')[-3]
        self.seqs = {
            x.id: x
            for x in SeqIO.parse(cluster_seq_filepath, 'fasta')
        }
        if 'run_dir' not in kwarks:
            import random
            self.run_dir = 'tmp/'
            self.run_dir += '_'.join([
                folderID, 'tmp',
                time.strftime('%H%M%S', time.gmtime()),
                str(random.randint(0, 100000000))
            ])
        else:
            self.run_dir = kwarks['run_dir']
        self.nuc = True

    def codon_align(self,
                    alignment_tool="mafft",
                    prune=True,
                    discard_premature_stops=False):
        '''
        takes the nucleotide sequences, translates them, aligns the amino
        acids, and pads the gaps back into the nucleotide sequences
        (result in self.aln). Note that this suppresses any compensated
        frameshift mutations.

        Parameters:
        - alignment_tool: ['mafft', 'muscle'] the commandline tool to use
        - prune, discard_premature_stops: a sequence whose translation has a
          stop codon before its last position is dropped only if both are true

        Raises:
        - ValueError: if alignment_tool is not supported
        '''
        # fail before touching the filesystem; previously an unsupported tool
        # left the amino acid alignment undefined and crashed further down
        if alignment_tool not in ('mafft', 'muscle'):
            raise ValueError('Alignment tool not supported:' + str(alignment_tool))

        cwd = os.getcwd()
        make_dir(self.run_dir)
        os.chdir(self.run_dir)
        try:
            # translate
            aa_seqs = {}
            for seq in self.seqs.values():
                tempseq = seq.seq.translate(table="Bacterial")
                # the final codon is allowed to be a stop
                if (not discard_premature_stops or not prune
                        or '*' not in str(tempseq)[:-1]):
                    aa_seqs[seq.id] = SeqRecord(tempseq, id=seq.id)
                else:
                    print(seq.id, "has premature stops, discarding")

            tmpfname = 'temp_in.fasta'
            SeqIO.write(aa_seqs.values(), tmpfname, 'fasta')

            if alignment_tool == 'mafft':
                os.system(
                    'mafft --reorder --amino temp_in.fasta 1> temp_out.fasta')
                aln_aa = AlignIO.read('temp_out.fasta', "fasta")
            else:  # 'muscle'
                from Bio.Align.Applications import MuscleCommandline
                cline = MuscleCommandline(input=tmpfname,
                                          out=tmpfname[:-5] + 'aligned.fasta')
                cline()
                aln_aa = AlignIO.read(tmpfname[:-5] + 'aligned.fasta', "fasta")

            # generate nucleotide alignment
            self.aln = pad_nucleotide_sequences(aln_aa, self.seqs)
        finally:
            # always return to the caller's directory, even on failure
            os.chdir(cwd)
        remove_dir(self.run_dir)

    def align(self):
        '''
        align sequences in self.seqs using mafft; result stored in self.aln
        '''
        cwd = os.getcwd()
        make_dir(self.run_dir)
        os.chdir(self.run_dir)
        try:
            SeqIO.write(self.seqs.values(), "temp_in.fasta", "fasta")
            os.system(
                'mafft --reorder --anysymbol temp_in.fasta 1> temp_out.fasta 2> mafft.log'
            )
            self.aln = AlignIO.read('temp_out.fasta', 'fasta')
        finally:
            # always return to the caller's directory, even on failure
            os.chdir(cwd)
        remove_dir(self.run_dir)

    def build(self,
              root='midpoint',
              raxml=True,
              fasttree_program='fasttree',
              raxml_time_limit=0.5,
              treetime_used=True):
        '''
        build a phylogenetic tree using fasttree and raxML (optional)
        based on nextflu tree building pipeline

        Parameters:
        - root: accepted for interface compatibility; not used
        - raxml: if false, the FastTree tree is post-processed with
          polytomies_midpointRooting instead of being refined by RAxML
        - fasttree_program: FastTree executable
        - raxml_time_limit: hours allowed for the RAxML topology search;
          <= 0 skips the search (branch lengths are still optimized)
        - treetime_used: load the result as treetime TreeAnc (self.tt);
          otherwise as a plain Bio.Phylo tree
        '''
        import subprocess
        cwd = os.getcwd()
        make_dir(self.run_dir)
        os.chdir(self.run_dir)
        try:
            AlignIO.write(self.aln, 'origin.fasta', 'fasta')
            # names are made newick-safe here and restored further down
            name_translation = make_strains_unique(self.aln)
            AlignIO.write(self.aln, 'temp.fasta', 'fasta')

            tree_cmd = [fasttree_program]
            if self.nuc:
                tree_cmd.append("-nt")
            tree_cmd.append("temp.fasta 1> initial_tree.newick 2> fasttree.log")
            os.system(" ".join(tree_cmd))

            out_fname = "tree_infer.newick"

            if not raxml:
                polytomies_midpointRooting('initial_tree.newick', out_fname,
                                           self.clusterID)
            elif len({x.id for x in SeqIO.parse('temp.fasta', 'fasta')}) > 3:
                # RAxML only for trees with >3 strains.
                # Both RAxML runs below read temp.phyx, so it is always written.
                AlignIO.write(self.aln, "temp.phyx", "phylip-relaxed")
                if raxml_time_limit > 0:
                    tmp_tree = Phylo.read('initial_tree.newick', 'newick')
                    resolve_iter = 0
                    resolve_polytomies(tmp_tree)
                    while (not tmp_tree.is_bifurcating()) and resolve_iter < 10:
                        resolve_iter += 1
                        resolve_polytomies(tmp_tree)
                    Phylo.write(tmp_tree, 'initial_tree.newick', 'newick')
                    print("RAxML tree optimization with time limit",
                          raxml_time_limit, "hours")
                    # 'exec' replaces the shell so terminate() reaches raxml itself
                    end_time = time.time() + int(raxml_time_limit * 3600)
                    process = subprocess.Popen(
                        "exec raxml -f d -j -s temp.phyx -n topology -c 25 -m GTRCAT -p 344312987 -t initial_tree.newick",
                        shell=True)
                    while time.time() < end_time:
                        if os.path.isfile('RAxML_result.topology'):
                            break
                        time.sleep(10)
                    process.terminate()
                    process.wait()  # reap the child process

                    # glob order is arbitrary: sort so the newest checkpoint is last
                    checkpoint_files = sorted(glob.glob("RAxML_checkpoint*"),
                                              key=os.path.getmtime)
                    if os.path.isfile('RAxML_result.topology'):
                        checkpoint_files.append('RAxML_result.topology')
                    if checkpoint_files:
                        shutil.copy(checkpoint_files[-1], 'raxml_tree.newick')
                    else:
                        shutil.copy("initial_tree.newick", 'raxml_tree.newick')
                else:
                    shutil.copy("initial_tree.newick", 'raxml_tree.newick')

                print("RAxML branch length optimization")
                os.system(
                    "raxml -f e -s temp.phyx -n branches -c 25 -m GTRGAMMA -p 344312987 -t raxml_tree.newick"
                )
                try:
                    shutil.copy('RAxML_result.branches', out_fname)
                except (IOError, OSError):
                    # os.system does not raise: a missing result file is how a
                    # RAxML failure shows up
                    print("RAxML branch length optimization failed")
                    shutil.copy('raxml_tree.newick', out_fname)
            else:
                # too few strains for RAxML: use the FastTree tree as is so
                # that out_fname always exists
                shutil.copy('initial_tree.newick', out_fname)

            if treetime_used:
                # load the resulting tree as a treetime instance
                from treetime import TreeAnc
                self.tt = TreeAnc(tree=out_fname,
                                  aln=self.aln,
                                  gtr='Jukes-Cantor',
                                  verbose=0)
                # provide short cut to tree and revert names that conflicted with newick format
                self.tree = self.tt.tree
            else:
                self.tree = Phylo.read(out_fname, 'newick')
            self.tree.root.branch_length = 0.0001
            restore_strain_name(name_translation, self.aln)
            restore_strain_name(name_translation, self.tree.get_terminals())

            # keep the original (annotated) name of every named non-NODE_ clade
            for node in self.tree.find_clades():
                if node.name is not None:
                    if not node.name.startswith('NODE_'):
                        node.ann = node.name
                else:
                    node.name = 'NODE_0'
        finally:
            # always return to the caller's directory, even on failure
            os.chdir(cwd)
        remove_dir(self.run_dir)
        self.is_timetree = False

    def ancestral(self, translate_tree=False):
        '''
        infer ancestral nucleotide sequences using maximum likelihood
        and translate the resulting sequences (+ terminals) to amino acids

        A failed reconstruction is reported and otherwise ignored
        (best effort), as before.
        '''
        try:
            self.tt.reconstruct_anc(method='ml')
        except Exception:
            print("trouble at self.tt.reconstruct_anc(method='ml')")

        if translate_tree:
            for node in self.tree.find_clades():
                node.aa_sequence = np.fromstring(str(
                    self.translate_seq("".join(node.sequence))),
                                                 dtype='S1')

    def refine(self, CDS=True):
        '''
        determine mutations on each branch and attach them as comma
        separated strings: node.muts (nucleotide mutations without gaps) and,
        if CDS, node.aa_muts ('<anc><1-based position><der>', gaps excluded)
        '''
        for node in self.tree.find_clades():
            if node.up is None:
                continue
            node.muts = ",".join([
                "".join(map(str, x)) for x in node.mutations
                if '-' not in x
            ])
            if CDS:
                node.aa_muts = ",".join([
                    anc + str(pos + 1) + der
                    for pos, (anc, der) in enumerate(
                        zip(node.up.aa_sequence, node.aa_sequence))
                    if anc != der and '-' not in anc and '-' not in der
                ])

    def translate_seq(self, seq):
        '''
        custom translation that handles gaps

        seq is a plain string or a record with a .seq attribute; translation
        uses the "Bacterial" table and unknown residues ('X') become '-'.
        '''
        if not isinstance(seq, str):
            str_seq = str(seq.seq)
        else:
            str_seq = seq
        try:
            # soon not needed as future biopython version will translate --- into -
            tmp_seq = Seq(
                str(
                    Seq(str_seq.replace('---', 'NNN')).translate(
                        table="Bacterial")).replace('X', '-'))
        except Exception:
            # presumably gaps that are not whole codons: mask every gap char
            tmp_seq = Seq(
                str(
                    Seq(str_seq.replace(
                        '-',
                        'N')).translate(table="Bacterial")).replace('X', '-'))
        return tmp_seq

    def translate(self):
        '''
        translate the nucleotide alignment to an amino acid alignment
        (stored in self.aa_aln)
        '''
        aa_seqs = [
            SeqRecord(seq=self.translate_seq(seq),
                      id=seq.id,
                      name=seq.name,
                      description=seq.description) for seq in self.aln
        ]
        self.aa_aln = MultipleSeqAlignment(aa_seqs)

    def mean_std_seqLen(self):
        """ return mean and standard deviation of sequence lengths """
        seqLen_arr = np.array([len(seq) for seq in self.seqs.values()])
        return np.mean(seqLen_arr, axis=0), np.std(seqLen_arr, axis=0)

    def paralogy_statistics(self):
        ''' number of paralog nodes and branch length of the best split '''
        best_split = find_best_split(self.tree)
        return len(best_split.para_nodes), best_split.branch_length

    def diversity_statistics_nuc(self):
        '''
        calculate the mean per-column nucleotide diversity
        (1 - sum of squared allele frequencies) into self.diversity_nuc
        '''
        if not hasattr(self, "aln"):
            print("calculate alignment first")
            return
        self.af_nuc = calc_af(self.aln, nuc_alpha)
        # the last two alphabet rows are excluded (presumably gap/ambiguous
        # symbols -- confirm against nuc_alpha); columns with mostly those
        # symbols are skipped
        is_valid = self.af_nuc[:-2].sum(axis=0) > 0.5
        tmp_af = self.af_nuc[:-2, is_valid] / self.af_nuc[:-2,
                                                          is_valid].sum(axis=0)
        self.diversity_nuc = np.mean(1.0 - (tmp_af**2).sum(axis=0))

    def diversity_statistics_aa(self):
        '''
        calculate the mean per-column amino acid diversity
        (1 - sum of squared residue frequencies) into self.diversity_aa
        '''
        # the amino acid alignment is what is read below (was: checked self.aln)
        if not hasattr(self, "aa_aln"):
            print("calculate alignment first")
            return
        self.af_aa = calc_af(self.aa_aln, aa_alpha)
        is_valid = self.af_aa[:-2].sum(axis=0) > 0.5
        tmp_af = self.af_aa[:-2, is_valid] / self.af_aa[:-2,
                                                        is_valid].sum(axis=0)
        self.diversity_aa = np.mean(1.0 - (tmp_af**2).sum(axis=0))

    def mutations_to_branch(self):
        ''' map every mutation to the list of branches (nodes) carrying it '''
        self.mut_to_branch = defaultdict(list)
        for node in self.tree.find_clades():
            if node.up is not None:
                for mut in node.mutations:
                    self.mut_to_branch[mut].append(node)

    def reduce_alignments(self, RNA_specific=False):
        '''
        build "reduced" alignments in which each residue equal to its column
        consensus is replaced by '.', with the consensus as first record.
        Stores self.aln_reduced and, unless RNA_specific, self.aa_aln_reduced.
        Requires self.af_nuc (see diversity_statistics_nuc).
        '''
        if RNA_specific:
            self.aa_aln = None
            self.af_aa = None
        else:
            self.af_aa = calc_af(self.aa_aln, aa_alpha)
        for attr, aln, alpha, freq in [[
                "aln_reduced", self.aln, nuc_alpha, self.af_nuc
        ], ["aa_aln_reduced", self.aa_aln, aa_alpha, self.af_aa]]:
            if RNA_specific and attr == "aa_aln_reduced":
                continue  # no reduced amino acid alignment for RNA
            try:
                consensus = np.array(list(alpha))[freq.argmax(axis=0)]
                aln_array = np.array(aln)
                aln_array[aln_array == consensus] = '.'
                new_seqs = [
                    SeqRecord(seq=Seq("".join(consensus)),
                              name="consensus",
                              id="consensus")
                ]
                for si, seq in enumerate(aln):
                    new_seqs.append(
                        SeqRecord(seq=Seq("".join(aln_array[si])),
                                  name=seq.name,
                                  id=seq.id,
                                  description=seq.description))
                self.__setattr__(attr, MultipleSeqAlignment(new_seqs))
            except Exception:
                print(
                    "sf_geneCluster_align_MakeTree: aligment reduction failed")

    def export(self, path='', extra_attr=None, RNA_specific=False):
        '''
        write tree (newick + json) and full/reduced alignments to files
        named path + clusterID + suffix

        Parameters:
        - path: output prefix (include the trailing separator)
        - extra_attr: node attributes written to the tree json; defaults to
          ['aa_muts', 'annotation', 'branch_length', 'name', 'accession']
        - RNA_specific: if true, no amino acid alignments are written
        '''
        if extra_attr is None:
            extra_attr = [
                'aa_muts', 'annotation', 'branch_length', 'name', 'accession'
            ]
        ## write tree
        Phylo.write(self.tree, path + self.clusterID + '.nwk', 'newick')

        ## processing node name, e.g.
        ## NZ_CP008870|HV97_RS21955-1-fabG_3-ketoacyl-ACP_reductase
        for node in self.tree.get_terminals():
            node.accession = node.ann.split('|')[0]
            node.name = node.ann.split('-')[0]
            annotation = node.ann.split('-', 2)
            if len(annotation) == 3:
                node.annotation = annotation[2]
            else:
                node.annotation = annotation[0]

        ## write tree json (branch lengths floored at 1e-6)
        for n in self.tree.root.find_clades():
            if n.branch_length < 1e-6:
                n.branch_length = 1e-6
        timetree_fname = path + self.clusterID + '_tree.json'
        tree_json = tree_to_json(self.tree.root, extra_attr=extra_attr)
        write_json(tree_json, timetree_fname, indent=None)

        self.reduce_alignments(RNA_specific)

        ## msa compatible
        for i_aln in self.aln:
            i_aln.id = i_aln.id.replace('|', '-', 1)
        for i_alnr in self.aln_reduced:
            i_alnr.id = i_alnr.id.replace('|', '-', 1)

        AlignIO.write(self.aln, path + self.clusterID + '_na_aln.fa', 'fasta')
        AlignIO.write(self.aln_reduced,
                      path + self.clusterID + '_na_aln_reduced.fa', 'fasta')

        if not RNA_specific:
            for i_aa_aln in self.aa_aln:
                i_aa_aln.id = i_aa_aln.id.replace('|', '-', 1)
            for i_aa_alnr in self.aa_aln_reduced:
                i_aa_alnr.id = i_aa_alnr.id.replace('|', '-', 1)

            AlignIO.write(self.aa_aln, path + self.clusterID + '_aa_aln.fa',
                          'fasta')
            AlignIO.write(self.aa_aln_reduced,
                          path + self.clusterID + '_aa_aln_reduced.fa',
                          'fasta')

        ## write seq json (disabled)
        write_seq_json = 0
        if write_seq_json:
            elems = {}
            for node in self.tree.find_clades():
                if hasattr(node, "sequence"):
                    if not hasattr(node, "longName"):
                        node.longName = node.name
                    elems[node.longName] = {}
                    nuc_dt = {
                        pos: state
                        for pos, (state, ancstate) in enumerate(
                            zip(node.sequence.tostring(),
                                self.tree.root.sequence.tostring()))
                        if state != ancstate
                    }
                    elems[node.longName]['nuc'] = nuc_dt

            elems['root'] = {}
            elems['root']['nuc'] = self.tree.root.sequence.tostring()

            self.sequences_fname = path + self.clusterID + '_seq.json'
            write_json(elems, self.sequences_fname, indent=None)
class mpm_tree(object):
    '''
    class that aligns a set of sequences and infers a tree
    '''

    def __init__(self, cluster_seq_filepath, **kwarks):
        '''
        Read the sequences of one gene cluster and pick a scratch directory.

        Parameters:
        - cluster_seq_filepath: FASTA file; its name up to '.fna' becomes
          self.clusterID and, unless 'speciesID' is passed, the third-to-last
          path component names the species folder
        - speciesID (keyword, optional): species folder override
        - run_dir (keyword, optional): scratch directory; otherwise
          'tmp/<species>_tmp_<HHMMSS>_<random>' is generated
        '''
        path_parts = cluster_seq_filepath.split('/')
        self.clusterID = path_parts[-1].split('.fna')[0]
        folderID = kwarks['speciesID'] if 'speciesID' in kwarks else path_parts[-3]
        self.seqs = dict((record.id, record)
                         for record in SeqIO.parse(cluster_seq_filepath, 'fasta'))
        if 'run_dir' in kwarks:
            self.run_dir = kwarks['run_dir']
        else:
            import random
            stamp = time.strftime('%H%M%S', time.gmtime())
            suffix = str(random.randint(0, 100000000))
            self.run_dir = 'tmp/' + '_'.join([folderID, 'tmp', stamp, suffix])
        self.nuc = True

    def codon_align(self, alignment_tool="mafft", prune=True, discard_premature_stops=False):
        '''
        takes a nucleotide alignment, translates it, aligns the amino acids, pads the gaps
        note that this suppresses any compensated frameshift mutations

        Parameters:
        - alignment_tool: ['mafft', 'muscle'] the commandline tool to use
        '''
        cwd = os.getcwd()
        make_dir(self.run_dir)
        os.chdir(self.run_dir)

        # translate
        aa_seqs = {}
        for seq in self.seqs.values():
            tempseq = seq.seq.translate(table="Bacterial")
            # use only sequences that translate without trouble
            if not discard_premature_stops or '*' not in str(tempseq)[:-1] or prune==False:
                aa_seqs[seq.id]=SeqRecord(tempseq,id=seq.id)
            else:
                print(seq.id,"has premature stops, discarding")

        tmpfname = 'temp_in.fasta'
        SeqIO.write(aa_seqs.values(), tmpfname,'fasta')

        if alignment_tool=='mafft':
            os.system('mafft --reorder --amino temp_in.fasta 1> temp_out.fasta')
            aln_aa = AlignIO.read('temp_out.fasta', "fasta")
        elif alignment_tool=='muscle':
            from Bio.Align.Applications import MuscleCommandline
            cline = MuscleCommandline(input=tmpfname, out=tmpfname[:-5]+'aligned.fasta')
            cline()
            aln_aa = AlignIO.read(tmpfname[:-5]+'aligned.fasta', "fasta")
        else:
            print 'Alignment tool not supported:'+alignment_tool
            #return

        #generate nucleotide alignment
        self.aln = pad_nucleotide_sequences(aln_aa, self.seqs)
        os.chdir(cwd)
        remove_dir(self.run_dir)

    def align(self):
        '''
        align sequences in self.seqs using mafft; result stored in self.aln
        '''
        cwd = os.getcwd()
        make_dir(self.run_dir)
        os.chdir(self.run_dir)
        try:
            SeqIO.write(self.seqs.values(), "temp_in.fasta", "fasta")
            os.system('mafft --reorder --anysymbol temp_in.fasta 1> temp_out.fasta 2> mafft.log')
            self.aln = AlignIO.read('temp_out.fasta', 'fasta')
        finally:
            # always return to the caller's directory, even on failure
            os.chdir(cwd)
        remove_dir(self.run_dir)


    def build(self, root='midpoint', raxml=True, fasttree_program='fasttree', raxml_time_limit=0.5, treetime_used=True):
        '''
        build a phylogenetic tree using fasttree and raxML (optional)
        based on nextflu tree building pipeline

        Parameters:
        - root: accepted for interface compatibility; not used
        - raxml: if false, the FastTree tree is post-processed with
          polytomies_midpointRooting instead of being refined by RAxML
        - fasttree_program: FastTree executable
        - raxml_time_limit: hours allowed for the RAxML topology search;
          <= 0 skips the search (branch lengths are still optimized)
        - treetime_used: load the result as treetime TreeAnc (self.tt);
          otherwise as a plain Bio.Phylo tree
        '''
        import subprocess
        cwd = os.getcwd()
        make_dir(self.run_dir)
        os.chdir(self.run_dir)
        try:
            AlignIO.write(self.aln, 'origin.fasta', 'fasta')
            # names are made newick-safe here and restored further down
            name_translation = make_strains_unique(self.aln)
            AlignIO.write(self.aln, 'temp.fasta', 'fasta')

            tree_cmd = [fasttree_program]
            if self.nuc:
                tree_cmd.append("-nt")
            tree_cmd.append("temp.fasta 1> initial_tree.newick 2> fasttree.log")
            os.system(" ".join(tree_cmd))

            out_fname = "tree_infer.newick"

            if not raxml:
                polytomies_midpointRooting('initial_tree.newick', out_fname, self.clusterID)
            elif len({x.id for x in SeqIO.parse('temp.fasta', 'fasta')}) > 3:
                # RAxML only for trees with >3 strains.
                # Both RAxML runs below read temp.phyx, so it is always written.
                AlignIO.write(self.aln, "temp.phyx", "phylip-relaxed")
                if raxml_time_limit > 0:
                    tmp_tree = Phylo.read('initial_tree.newick', 'newick')
                    resolve_iter = 0
                    resolve_polytomies(tmp_tree)
                    while (not tmp_tree.is_bifurcating()) and resolve_iter < 10:
                        resolve_iter += 1
                        resolve_polytomies(tmp_tree)
                    Phylo.write(tmp_tree, 'initial_tree.newick', 'newick')
                    print("RAxML tree optimization with time limit", raxml_time_limit, "hours")
                    # 'exec' replaces the shell so terminate() reaches raxml itself
                    end_time = time.time() + int(raxml_time_limit * 3600)
                    process = subprocess.Popen("exec raxml -f d -j -s temp.phyx -n topology -c 25 -m GTRCAT -p 344312987 -t initial_tree.newick", shell=True)
                    while time.time() < end_time:
                        if os.path.isfile('RAxML_result.topology'):
                            break
                        time.sleep(10)
                    process.terminate()
                    process.wait()  # reap the child process

                    # glob order is arbitrary: sort so the newest checkpoint is last
                    checkpoint_files = sorted(glob.glob("RAxML_checkpoint*"), key=os.path.getmtime)
                    if os.path.isfile('RAxML_result.topology'):
                        checkpoint_files.append('RAxML_result.topology')
                    if checkpoint_files:
                        shutil.copy(checkpoint_files[-1], 'raxml_tree.newick')
                    else:
                        shutil.copy("initial_tree.newick", 'raxml_tree.newick')
                else:
                    shutil.copy("initial_tree.newick", 'raxml_tree.newick')

                print("RAxML branch length optimization")
                os.system("raxml -f e -s temp.phyx -n branches -c 25 -m GTRGAMMA -p 344312987 -t raxml_tree.newick")
                try:
                    shutil.copy('RAxML_result.branches', out_fname)
                except (IOError, OSError):
                    # os.system does not raise: a missing result file is how a
                    # RAxML failure shows up
                    print("RAxML branch length optimization failed")
                    shutil.copy('raxml_tree.newick', out_fname)
            else:
                # too few strains for RAxML: use the FastTree tree as is so
                # that out_fname always exists
                shutil.copy('initial_tree.newick', out_fname)

            if treetime_used:
                # load the resulting tree as a treetime instance
                from treetime import TreeAnc
                self.tt = TreeAnc(tree=out_fname, aln=self.aln, gtr='Jukes-Cantor', verbose=0)
                # provide short cut to tree and revert names that conflicted with newick format
                self.tree = self.tt.tree
            else:
                self.tree = Phylo.read(out_fname, 'newick')
            self.tree.root.branch_length = 0.0001
            restore_strain_name(name_translation, self.aln)
            restore_strain_name(name_translation, self.tree.get_terminals())

            # keep the original (annotated) name of every named non-NODE_ clade
            for node in self.tree.find_clades():
                if node.name is not None:
                    if not node.name.startswith('NODE_'):
                        node.ann = node.name
                else:
                    node.name = 'NODE_0'
        finally:
            # always return to the caller's directory, even on failure
            os.chdir(cwd)
        remove_dir(self.run_dir)
        self.is_timetree = False

    def ancestral(self, translate_tree = False):
        '''
        Infer ancestral nucleotide sequences by maximum likelihood and,
        optionally, translate every node's sequence to amino acids.

        Parameters
        ----------
        translate_tree : bool
            if true, attach ``node.aa_sequence`` (numpy array of single
            characters, dtype 'S1') to every clade of ``self.tree``, built
            from ``self.translate_seq("".join(node.sequence))``.

        A failing reconstruction (e.g. no ``self.tt`` because the tree was
        not loaded through TreeAnc) is reported and otherwise ignored, as
        before; translation then uses whatever ``node.sequence`` exists.
        '''
        try:
            self.tt.reconstruct_anc(method='ml')
        except Exception as err:
            # best effort, but say *why* it failed instead of hiding it;
            # KeyboardInterrupt/SystemExit are no longer swallowed
            print("trouble at self.tt.reconstruct_anc(method='ml'): %s" % err)

        if translate_tree:
            for node in self.tree.find_clades():
                aa_seq = str(self.translate_seq("".join(node.sequence)))
                # np.fromstring's binary mode is deprecated; build the
                # per-character 'S1' array explicitly (same result)
                node.aa_sequence = np.array(list(aa_seq), dtype='S1')

    def refine(self, CDS = True):
        '''
        Label every non-root clade with the mutations on its incoming branch.

        ``clade.muts``: comma-separated nucleotide mutations taken from
        ``clade.mutations`` (each entry joined as strings), skipping entries
        that contain '-'.
        ``clade.aa_muts`` (only when ``CDS == True``): comma-separated
        amino-acid changes between the parent's and the clade's
        ``aa_sequence``, written as <parent><1-based position><child>, with
        positions where either side contains '-' left out.
        '''
        for clade in self.tree.find_clades():
            parent = clade.up
            if parent is None:
                # the root has no incoming branch
                continue

            nuc_labels = []
            for mut in clade.mutations:
                if '-' in mut:
                    continue
                nuc_labels.append("".join(str(part) for part in mut))
            clade.muts = ",".join(nuc_labels)

            if CDS == True:
                aa_labels = []
                paired = zip(parent.aa_sequence, clade.aa_sequence)
                for idx, (before, after) in enumerate(paired):
                    if before == after or '-' in before or '-' in after:
                        continue
                    aa_labels.append(before + str(idx + 1) + after)
                clade.aa_muts = ",".join(aa_labels)


    def translate_seq(self, seq):
        '''
        Translate a nucleotide sequence to amino acids, tolerating gaps.

        Parameters
        ----------
        seq : str or record
            a plain string, or any object with a ``.seq`` attribute (e.g. a
            SeqRecord), in which case ``str(seq.seq)`` is translated.

        Returns
        -------
        Seq
            translation using the "Bacterial" codon table in which every 'X'
            (gap codons and ambiguous codons alike) is replaced by '-'.
        '''
        # the original `type(seq) not in [str, unicode]` check only works on
        # Python 2; test for the record interface instead
        if hasattr(seq, 'seq'):
            str_seq = str(seq.seq)
        else:
            str_seq = str(seq)
        try:
            # soon not needed as future biopython version will translate --- into -
            # gap codons become NNN, translate to X, and are mapped to '-' below
            aa_seq = Seq(str_seq.replace('---', 'NNN')).translate(table="Bacterial")
        except Exception:
            # e.g. gaps not aligned to codon boundaries: treat every gap as N
            aa_seq = Seq(str_seq.replace('-', 'N')).translate(table="Bacterial")
        return Seq(str(aa_seq).replace('X', '-'))

    def translate(self):
        '''
        Set ``self.aa_aln`` to the amino-acid version of ``self.aln``: one
        record per nucleotide record, translated with ``self.translate_seq``
        and carrying over its id, name and description.
        '''
        translated = [SeqRecord(seq=self.translate_seq(record),
                                id=record.id,
                                name=record.name,
                                description=record.description)
                      for record in self.aln]
        self.aa_aln = MultipleSeqAlignment(translated)

    def mean_std_seqLen(self):
        """ return mean and standard deviation of the lengths of self.seqs values """
        lengths = np.array(list(map(len, self.seqs.values())))
        mean_len = lengths.mean(axis=0)
        std_len = lengths.std(axis=0)
        return mean_len, std_len

    def paralogy_statistics(self):
        '''
        Return (number of paralog nodes, branch length) of the split that
        ``find_best_split`` selects on ``self.tree``.
        '''
        split = find_best_split(self.tree)
        n_para = len(split.para_nodes)
        return n_para, split.branch_length

    def diversity_statistics_nuc(self):
        '''
        Calculate the mean nucleotide diversity of ``self.aln``.

        Stores the per-column symbol frequencies from ``calc_af`` in
        ``self.af_nuc`` (rows = symbols of ``nuc_alpha``, columns =
        alignment positions). For every column in which the frequency mass
        of all but the last two symbols exceeds 0.5, those frequencies are
        renormalised and the column diversity 1 - sum_i p_i**2 is taken;
        the mean over such columns is stored in ``self.diversity_nuc``.
        Prints a message and returns None if ``self.aln`` does not exist.
        '''
        if not hasattr(self, "aln"):
            print("calculate alignment first")
            return
        self.af_nuc = calc_af(self.aln, nuc_alpha)
        # drop the last two alphabet rows (presumably gap / ambiguous
        # symbols -- confirm against nuc_alpha)
        core_af = self.af_nuc[:-2]
        is_valid = core_af.sum(axis=0)>0.5
        tmp_af = core_af[:, is_valid]/core_af[:, is_valid].sum(axis=0)
        self.diversity_nuc = np.mean(1.0-(tmp_af**2).sum(axis=0))

    def diversity_statistics_aa(self):
        '''
        Calculate the mean amino-acid diversity of ``self.aa_aln``.

        Stores the per-column symbol frequencies from ``calc_af`` in
        ``self.af_aa`` (rows = symbols of ``aa_alpha``, columns =
        alignment positions). For every column in which the frequency mass
        of all but the last two symbols exceeds 0.5, those frequencies are
        renormalised and the column diversity 1 - sum_i p_i**2 is taken;
        the mean over such columns is stored in ``self.diversity_aa``.
        Prints a message and returns None if there is no amino-acid
        alignment (``translate`` not run, or ``self.aa_aln`` reset to None).
        '''
        # the amino-acid alignment is what is used below, so guard on it
        # (checking only `aln` let a missing/None aa_aln crash calc_af)
        if getattr(self, "aa_aln", None) is None:
            print("calculate amino acid alignment first")
            return
        self.af_aa = calc_af(self.aa_aln, aa_alpha)
        # drop the last two alphabet rows (presumably gap / ambiguous
        # symbols -- confirm against aa_alpha)
        core_af = self.af_aa[:-2]
        is_valid = core_af.sum(axis=0)>0.5
        tmp_af = core_af[:, is_valid]/core_af[:, is_valid].sum(axis=0)
        self.diversity_aa = np.mean(1.0-(tmp_af**2).sum(axis=0))

    def mutations_to_branch(self):
        '''
        Index the tree's branches by mutation.

        Sets ``self.mut_to_branch`` to a defaultdict(list) mapping each entry
        of ``clade.mutations`` to the non-root clades carrying it, in
        ``find_clades`` order.
        '''
        self.mut_to_branch = branches_by_mut = defaultdict(list)
        non_root = (clade for clade in self.tree.find_clades() if clade.up is not None)
        for clade in non_root:
            for mutation in clade.mutations:
                branches_by_mut[mutation].append(clade)

    def reduce_alignments(self,RNA_specific=False):
        '''
        Build "reduced" alignments in which every residue equal to its
        column consensus (most frequent symbol) is replaced by '.', preceded
        by a record with id/name "consensus" holding the consensus itself.

        Results are stored in ``self.aln_reduced`` (nucleotides) and, unless
        ``RNA_specific``, ``self.aa_aln_reduced`` (amino acids).

        Parameters
        ----------
        RNA_specific : bool
            if true there is no protein: ``self.aa_aln`` and ``self.af_aa``
            are reset to None and no reduced amino-acid alignment is built.

        Side effects: ``self.af_aa`` is always recomputed (or reset) and
        ``self.af_nuc`` is computed if it is missing. A failure to reduce
        one alignment is printed and does not prevent the other.
        '''
        if RNA_specific:
            self.aa_aln=None
            self.af_aa =None
        else:
            self.af_aa= calc_af(self.aa_aln, aa_alpha)
        # af_nuc is normally set by diversity_statistics_nuc; compute it
        # here when that has not run instead of failing with AttributeError
        if getattr(self, "af_nuc", None) is None:
            self.af_nuc = calc_af(self.aln, nuc_alpha)

        targets = [["aln_reduced", self.aln, nuc_alpha, self.af_nuc]]
        if not RNA_specific:
            # no reduced amino acid alignment for RNA
            targets.append(["aa_aln_reduced", self.aa_aln, aa_alpha, self.af_aa])

        for attr, aln, alpha, freq in targets:
            try:
                # freq rows follow alpha, so argmax over rows picks the
                # most frequent symbol of each column
                consensus = np.array(list(alpha))[freq.argmax(axis=0)]
                aln_array = np.array(aln)
                aln_array[aln_array==consensus]='.'
                new_seqs = [SeqRecord(seq=Seq("".join(consensus)), name="consensus", id="consensus")]
                for si, seq in enumerate(aln):
                    new_seqs.append(SeqRecord(seq=Seq("".join(aln_array[si])), name=seq.name,
                                              id=seq.id, description=seq.description))
                setattr(self, attr, MultipleSeqAlignment(new_seqs))
            except Exception as err:
                print("sf_geneCluster_align_MakeTree: alignment reduction failed for %s: %s" % (attr, err))


    #def export(self, path = '', extra_attr = ['aa_muts','ann','branch_length','name','longName'], RNA_specific=False):
    def export(self, path = '', extra_attr = ['aa_muts','annotation','branch_length','name','accession'], RNA_specific=False):
        ## write tree
        Phylo.write(self.tree, path+self.clusterID+'.nwk', 'newick')

        ## processing node name
        for node in self.tree.get_terminals():
            #node.name = node.ann.split('|')[0]
            node.accession = node.ann.split('|')[0]
            #node.longName = node.ann.split('-')[0]
            node.name = node.ann.split('-')[0]
            #NZ_CP008870|HV97_RS21955-1-fabG_3-ketoacyl-ACP_reductase
            annotation=node.ann.split('-',2)
            if len(annotation)==3:
                node.annotation= annotation[2]
            else:
                node.annotation= annotation[0]

        ## write tree json
        for n in self.tree.root.find_clades():
            if n.branch_length<1e-6:
                n.branch_length = 1e-6
        timetree_fname = path+self.clusterID+'_tree.json'
        tree_json = tree_to_json(self.tree.root, extra_attr=extra_attr)
        write_json(tree_json, timetree_fname, indent=None)

        self.reduce_alignments(RNA_specific)

        ## msa compatible
        for i_aln in self.aln:
            i_aln.id=i_aln.id.replace('|','-',1)
        for i_alnr in self.aln_reduced:
            i_alnr.id=i_alnr.id.replace('|','-',1)

        AlignIO.write(self.aln, path+self.clusterID+'_na_aln.fa', 'fasta')
        AlignIO.write(self.aln_reduced, path+self.clusterID+'_na_aln_reduced.fa', 'fasta')

        if RNA_specific==False:
            for i_aa_aln in self.aa_aln:
                i_aa_aln.id=i_aa_aln.id.replace('|','-',1)
            for i_aa_alnr in self.aa_aln_reduced:
                i_aa_alnr.id=i_aa_alnr.id.replace('|','-',1)

            AlignIO.write(self.aa_aln, path+self.clusterID+'_aa_aln.fa', 'fasta')
            AlignIO.write(self.aa_aln_reduced, path+self.clusterID+'_aa_aln_reduced.fa', 'fasta')

        ## write seq json
        write_seq_json=0
        if write_seq_json:
            elems = {}
            for node in self.tree.find_clades():
                if hasattr(node, "sequence"):
                    if hasattr(node, "longName")==False:
                        node.longName=node.name
                    elems[node.longName] = {}
                    nuc_dt= {pos:state for pos, (state, ancstate) in
                                    enumerate(izip(node.sequence.tostring(), self.tree.root.sequence.tostring())) if state!=ancstate}
                    nodeseq=node.sequence.tostring();nodeseq_len=len(nodeseq)
                    elems[node.longName]['nuc']=nuc_dt

            elems['root'] = {}
            elems['root']['nuc'] = self.tree.root.sequence.tostring()

            self.sequences_fname=path+self.clusterID+'_seq.json'
            write_json(elems, self.sequences_fname, indent=None)