Example #1
class tree(object):
    """tree builds a phylgenetic tree from an alignment and exports it for web visualization"""
    def __init__(self, aln, proteins=None, **kwarks):
        super(tree, self).__init__()
        self.aln = aln
        self.nthreads = 2
        self.sequence_lookup = {seq.id: seq for seq in aln}
        self.nuc = kwarks['nuc'] if 'nuc' in kwarks else True
        self.dump_attr = []
        if proteins != None:
            self.proteins = proteins
        else:
            self.proteins = {}
        if 'run_dir' not in kwarks:
            import random
            self.run_dir = '_'.join([
                'temp',
                time.strftime('%Y%m%d-%H%M%S', time.gmtime()),
                str(random.randint(0, 1000000))
            ])
        else:
            self.run_dir = kwarks['run_dir']

    def dump(self, treefile, nodefile):
        from Bio import Phylo
        Phylo.write(self.tree, treefile, 'newick')
        node_props = {}
        for node in self.tree.find_clades():
            node_props[node.name] = {
                attr: node.__getattribute__(attr)
                for attr in self.dump_attr if hasattr(node, attr)
            }

        with myopen(nodefile, 'w') as nfile:
            from cPickle import dump
            dump(node_props, nfile)

    def build(self,
              root='midpoint',
              raxml=True,
              raxml_time_limit=0.5,
              raxml_bin='raxml'):
        from Bio import Phylo, AlignIO
        import subprocess, glob, shutil
        make_dir(self.run_dir)
        os.chdir(self.run_dir)
        for seq in self.aln:
            seq.name = seq.id
        AlignIO.write(self.aln, 'temp.fasta', 'fasta')

        tree_cmd = ["fasttree"]
        if self.nuc: tree_cmd.append("-nt")
        tree_cmd.append("temp.fasta")
        tree_cmd.append(">")
        tree_cmd.append("initial_tree.newick")
        os.system(" ".join(tree_cmd))

        out_fname = "tree_infer.newick"
        if raxml:
            if raxml_time_limit > 0:
                tmp_tree = Phylo.read('initial_tree.newick', 'newick')
                resolve_iter = 0
                resolve_polytomies(tmp_tree)
                while (not tmp_tree.is_bifurcating()) and (resolve_iter < 10):
                    resolve_iter += 1
                    resolve_polytomies(tmp_tree)
                Phylo.write(tmp_tree, 'initial_tree.newick', 'newick')
                AlignIO.write(self.aln, "temp.phyx", "phylip-relaxed")
                print("RAxML tree optimization with time limit",
                      raxml_time_limit, "hours")
                # using exec to be able to kill process
                end_time = time.time() + int(raxml_time_limit * 3600)
                process = subprocess.Popen(
                    "exec " + raxml_bin + " -f d -T " + str(self.nthreads) +
                    " -j -s temp.phyx -n topology -c 25 -m GTRCAT -p 344312987 -t initial_tree.newick",
                    shell=True)
                while (time.time() < end_time):
                    if os.path.isfile('RAxML_result.topology'):
                        break
                    time.sleep(10)
                process.terminate()

                checkpoint_files = glob.glob("RAxML_checkpoint*")
                if os.path.isfile('RAxML_result.topology'):
                    checkpoint_files.append('RAxML_result.topology')
                if len(checkpoint_files) > 0:
                    last_tree_file = checkpoint_files[-1]
                    shutil.copy(last_tree_file, 'raxml_tree.newick')
                else:
                    shutil.copy("initial_tree.newick", 'raxml_tree.newick')
            else:
                shutil.copy("initial_tree.newick", 'raxml_tree.newick')

            try:
                print("RAxML branch length optimization")
                os.system(
                    raxml_bin + " -f e -T " + str(self.nthreads) +
                    " -s temp.phyx -n branches -c 25 -m GTRGAMMA -p 344312987 -t raxml_tree.newick"
                )
                shutil.copy('RAxML_result.branches', out_fname)
            except:
                print("RAxML branch length optimization failed")
                shutil.copy('raxml_tree.newick', out_fname)
        else:
            shutil.copy('initial_tree.newick', out_fname)
        self.tt_from_file(out_fname, root)
        os.chdir('..')
        remove_dir(self.run_dir)

    def tt_from_file(self, infile, root='best', nodefile=None):
        from treetime import TreeTime
        from treetime import utils
        self.is_timetree = False
        print('Reading tree from file', infile)
        dates = {
            seq.id: seq.attributes['num_date']
            for seq in self.aln if 'date' in seq.attributes
        }
        self.tt = TreeTime(dates=dates,
                           tree=infile,
                           gtr='Jukes-Cantor',
                           aln=self.aln,
                           verbose=4)
        if root:
            self.tt.reroot(root=root)
        self.tree = self.tt.tree

        for node in self.tree.find_clades():
            if node.is_terminal() and node.name in self.sequence_lookup:
                seq = self.sequence_lookup[node.name]
                node.attr = seq.attributes
                try:
                    node.attr['date'] = node.attr['date'].strftime('%Y-%m-%d')
                except:
                    pass
            else:
                node.attr = {}

        if nodefile is not None:
            print('reading node properties from file:', nodefile)
            with myopen(nodefile, 'r') as infile:
                from cPickle import load
                node_props = load(infile)
            for n in self.tree.find_clades():
                if n.name in node_props:
                    for attr in node_props[n.name]:
                        n.__setattr__(attr, node_props[n.name][attr])
                else:
                    print("No node properties found for ", n.name)

    def ancestral(self, **kwarks):
        self.tt.optimize_seq_and_branch_len(infer_gtr=True, **kwarks)
        self.dump_attr.append('sequence')
        for node in self.tree.find_clades():
            if not hasattr(node, 'attr'):
                node.attr = {}

    def timetree(self,
                 Tc=0.01,
                 infer_gtr=True,
                 reroot='best',
                 resolve_polytomies=True,
                 max_iter=2,
                 **kwarks):
        self.tt.run(infer_gtr=infer_gtr,
                    root=reroot,
                    Tc=Tc,
                    resolve_polytomies=resolve_polytomies,
                    max_iter=max_iter)
        print('estimating time tree...')
        self.dump_attr.extend(['numdate', 'date', 'sequence'])
        for node in self.tree.find_clades():
            if hasattr(node, 'attr'):
                node.attr['num_date'] = node.numdate
            else:
                node.attr = {'num_date': node.numdate}
        self.is_timetree = True

    def geo_inference(self, attr):
        '''
        infer a "mugration" model by pretending each region corresponds to a sequence
        state and repurposing the GTR inference and ancestral reconstruction
        '''
        from treetime.gtr import GTR
        # Determine alphabet and store reconstructed ancestral sequences
        places = set()
        nuc_seqs = {}
        nuc_muts = {}
        nuc_seq_LH = None
        if hasattr(self.tt.tree, 'sequence_LH'):
            nuc_seq_LH = self.tt.tree.sequence_LH
        for node in self.tree.find_clades():
            if hasattr(node, 'attr'):
                if attr in node.attr:
                    places.add(node.attr[attr])
            if hasattr(node, 'sequence'):
                nuc_seqs[node] = node.sequence
            if hasattr(node, 'mutations'):
                nuc_muts[node] = node.mutations
                node.__delattr__('mutations')

        # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45)
        places = sorted(places)
        nc = len(places)
        if nc < 2 or nc > 180:
            print(
                "geo_inference: can't have less than 2 or more than 180 places!"
            )
            return

        alphabet = {chr(65 + i): place for i, place in enumerate(places)}
        alphabet_rev = {v: k for k, v in alphabet.iteritems()}
        sequence_gtr = self.tt.gtr
        myGeoGTR = GTR.custom(pi=np.ones(nc, dtype=float) / nc,
                              W=np.ones((nc, nc)),
                              alphabet=np.array(sorted(alphabet.keys())))
        myGeoGTR.profile_map['-'] = np.ones(nc)
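        # Worked illustration of the mapping above (hypothetical places): with
        # places = ['africa', 'asia', 'europe'], chr(65+i) yields the alphabet
        # {'A': 'africa', 'B': 'asia', 'C': 'europe'}; alphabet_rev inverts it, so each
        # tip's location is encoded as a one-letter "sequence" over A/B/C. The flat
        # profile assigned to '-' makes tips with unknown location uninformative
        # rather than forcing a particular state on the reconstruction.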

        # set geo info to nodes as one letter sequence.
        for node in self.tree.get_terminals():
            if hasattr(node, 'attr') and attr in node.attr:
                node.sequence = np.array([alphabet_rev[node.attr[attr]]])
            else:
                # no location information -> use the uninformative '-' state
                node.sequence = np.array(['-'])
        for node in self.tree.get_nonterminals():
            node.__delattr__('sequence')
        # set custom GTR model, run inference
        self.tt._gtr = myGeoGTR
        tmp_use_mutation_length = self.tt.use_mutation_length
        self.tt.use_mutation_length = False
        self.tt.infer_ancestral_sequences(method='ml',
                                          infer_gtr=True,
                                          store_compressed=False,
                                          pc=5.0,
                                          marginal=True)

        # restore the nucleotide sequence and mutations to maintain expected behavior
        self.tt.geogtr = self.tt.gtr
        self.tt.geogtr.alphabet_to_location = alphabet
        self.tt._gtr = sequence_gtr
        self.dump_attr.append(attr)
        if hasattr(self.tt.tree, 'sequence_LH'):
            self.tt.tree.geo_LH = self.tt.tree.sequence_LH
            self.tt.tree.sequence_LH = nuc_seq_LH
        for node in self.tree.find_clades():
            node.attr[attr] = alphabet[node.sequence[0]]
            if node in nuc_seqs:
                node.sequence = nuc_seqs[node]
            if node.up is not None:
                node.__setattr__(attr + '_transitions', node.mutations)
                if node in nuc_muts:
                    node.mutations = nuc_muts[node]

        self.tt.use_mutation_length = tmp_use_mutation_length

    def get_attr_list(self, get_attr):
        states = []
        for node in self.tree.find_clades():
            if get_attr in node.attr:
                states.append(node.attr[get_attr])
        return states

    def add_translations(self):
        '''
        translate the nucleotide sequence into the proteins specified
        in self.proteins. these are expected to be SeqFeatures
        '''
        from Bio import Seq
        for node in self.tree.find_clades(order='preorder'):
            if not hasattr(node, "translations"):
                node.translations = {}
                node.aa_mutations = {}
            if node.up is None:
                for prot in self.proteins:
                    node.translations[prot] = Seq.translate(
                        str(self.proteins[prot].extract(
                            Seq.Seq("".join(node.sequence)))).replace(
                                '-', 'N'))
                    node.aa_mutations[prot] = []
            else:
                for prot in self.proteins:
                    node.translations[prot] = Seq.translate(
                        str(self.proteins[prot].extract(
                            Seq.Seq("".join(node.sequence)))).replace(
                                '-', 'N'))
                    node.aa_mutations[prot] = [
                        (a, pos, d) for pos, (a, d) in enumerate(
                            zip(node.up.translations[prot],
                                node.translations[prot])) if a != d
                    ]
        self.dump_attr.append('translations')

    def refine(self):
        '''
        add attributes for export, currently this is only muts and aa_muts
        '''
        self.tree.ladderize()
        for node in self.tree.find_clades():
            if node.up is not None:
                node.muts = [
                    "".join(map(str, [a, pos + 1, d]))
                    for a, pos, d in node.mutations
                ]
                node.aa_muts = {}
                if hasattr(node, 'translations'):
                    for prot in node.translations:
                        node.aa_muts[prot] = [
                            "".join(map(str, [a, pos + 1, d]))
                            for a, pos, d in node.aa_mutations[prot]
                        ]
        for node in self.tree.find_clades(order="preorder"):
            if node.up is not None:  #try:
                node.attr["div"] = node.up.attr["div"] + node.mutation_length
            else:
                node.attr["div"] = 0
        self.dump_attr.extend([
            'muts', 'aa_muts', 'aa_mutations', 'mutation_length', 'mutations'
        ])

    def layout(self):
        """Add clade, xvalue, yvalue, mutation and trunk attributes to all nodes in tree"""
        clade = 0
        yvalue = self.tree.count_terminals()
        for node in self.tree.find_clades(order="preorder"):
            node.clade = clade
            clade += 1
            if node.up is not None:  #try:
                node.xvalue = node.up.xvalue + node.mutation_length
                if self.is_timetree:
                    node.tvalue = node.numdate - self.tree.root.numdate
                else:
                    node.tvalue = 0
            else:
                node.xvalue = 0
                node.tvalue = 0
            if node.is_terminal():
                node.yvalue = yvalue
                yvalue -= 1
        for node in self.tree.get_nonterminals(order="postorder"):
            node.yvalue = np.mean([x.yvalue for x in node.clades])
        self.dump_attr.extend(['yvalue', 'xvalue', 'clade'])
        if self.is_timetree:
            self.dump_attr.extend(['tvalue'])

    def export(self,
               path='',
               extra_attr=['aa_muts', 'clade'],
               plain_export=10,
               indent=None):
        '''
        export the tree data structure along with the sequence information as
        json files for display in web browsers.
        parameters:
            path    -- path (incl prefix) to which the output files are written.
                       filenames themselves are standardized to *tree.json and *sequences.json
            extra_attr -- attributes of tree nodes that are exported to json
            plain_export -- store sequences as plain strings instead of
                            differences to the root if the number of differences exceeds
                            len(seq)/plain_export
        '''
        from Bio import Seq
        from itertools import izip
        timetree_fname = path + 'tree.json'
        sequence_fname = path + 'sequences.json'
        tree_json = tree_to_json(self.tree.root, extra_attr=extra_attr)
        write_json(tree_json, timetree_fname, indent=indent)

        # prepare a json with sequence information to export.
        # first step: add the sequence & translations of the root as string
        elems = {}
        elems['root'] = {}
        elems['root']['nuc'] = "".join(self.tree.root.sequence)
        for prot, seq in self.tree.root.translations.iteritems():
            elems['root'][prot] = seq

        # add sequence for every node in tree. code as difference to root
        # or as full strings.
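        # e.g. with the default plain_export=10 and a 1000-site alignment, a node is
        # stored as a {position: state} diff against the root if it has at most 100
        # differences; otherwise the full sequence string is written out instead.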
        for node in self.tree.find_clades():
            if hasattr(node, "clade"):
                elems[node.clade] = {}
                # loop over proteins and nucleotide sequences
                for prot, seq in [('nuc', "".join(node.sequence))
                                  ] + node.translations.items():
                    differences = {
                        pos: state
                        for pos, (state, ancstate) in enumerate(
                            izip(seq, elems['root'][prot]))
                        if state != ancstate
                    }
                    if plain_export * len(differences) <= len(seq):
                        elems[node.clade][prot] = differences
                    else:
                        elems[node.clade][prot] = seq

        write_json(elems, sequence_fname, indent=indent)
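
A minimal usage sketch for this first variant of the class (not part of the original example): it assumes the module-level helpers it references (myopen, make_dir, remove_dir, resolve_polytomies, tree_to_json, write_json) and imports (os, time, np) exist elsewhere in the module, that the fasttree/raxml binaries are on the PATH, and that each alignment record carries an attributes dict with 'date', 'num_date' and, for geo_inference, a 'region' that varies across records. All file names and attribute values below are hypothetical.

import datetime
from Bio import AlignIO

aln = AlignIO.read("sequences_aligned.fasta", "fasta")
for rec in aln:
    # normally parsed from a metadata file; values differ per record
    rec.attributes = {'date': datetime.date(2016, 7, 1), 'num_date': 2016.5, 'region': 'europe'}

T = tree(aln, proteins={})            # proteins: {name: SeqFeature} used by add_translations()
T.build(root='best', raxml=False)     # FastTree topology only, skip the RAxML refinement
T.ancestral()                         # ancestral sequence reconstruction via TreeTime
T.timetree(Tc=0.01, reroot='best')    # molecular-clock dating
T.geo_inference('region')             # "mugration" model over the 'region' attribute
T.add_translations()
T.refine()
T.layout()
T.export(path='output/flu_')          # writes output/flu_tree.json and output/flu_sequences.json
T.dump('output/flu_tree.newick', 'output/flu_nodes.pkl')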
Example #2
class tree(object):
    """tree builds a phylgenetic tree from an alignment and exports it for web visualization"""
    def __init__(self, aln, proteins=None, verbose=2, logger=None, **kwarks):
        super(tree, self).__init__()
        self.aln = aln
        # self.nthreads = 2
        self.sequence_lookup = {seq.id:seq for seq in aln}
        self.nuc = kwarks['nuc'] if 'nuc' in kwarks else True
        self.dump_attr = [] # deprecated
        self.verbose = verbose
        if proteins!=None:
            self.proteins = proteins
        else:
            self.proteins={}
        if 'run_dir' not in kwarks:
            import random
            self.run_dir = '_'.join(['temp', time.strftime('%Y%m%d-%H%M%S',time.gmtime()), str(random.randint(0,1000000))])
        else:
            self.run_dir = kwarks['run_dir']
        if logger is None:
            def f(x,y):
                if y<self.verbose: print(x)
            self.logger = f
        else:
            self.logger=logger

    def getDateMin(self):
        return self.tree.root.date

    def getDateMax(self):
        dateMax = self.tree.root.date
        for node in self.tree.find_clades():
            if node.date > dateMax:
                dateMax = node.date
        return dateMax

    def dump(self, treefile, nodefile):
        from Bio import Phylo
        Phylo.write(self.tree, treefile, 'newick')
        node_props = {}
        for node in self.tree.find_clades():
            node_props[node.name] = {attr:node.__getattribute__(attr) for attr in self.dump_attr if hasattr(node, attr)}

        with myopen(nodefile, 'w') as nfile:
            pickle.dump(node_props, nfile)

    def check_newick(self, newick_file):
        try:
            tree = Phylo.parse(newick_file, 'newick').next()
            assert(set([x.name for x in tree.get_terminals()]) == set(self.sequence_lookup.keys()))
            return True
        except:
            return False

    def build_newick(self, newick_file, nthreads=2, method="raxml", raxml_options={},
                     iqtree_options={}, debug=False):
        make_dir(self.run_dir)
        os.chdir(self.run_dir)
        for seq in self.aln: seq.name=seq.id
        out_fname = os.path.join("..", newick_file)
        if method=="raxml":
            self.build_newick_raxml(out_fname, nthreads=nthreads, **raxml_options)
        elif method=="fasttree":
            self.build_newick_fasttree(out_fname)
        elif method=="iqtree":
            self.build_newick_iqtree(out_fname, **iqtree_options)
        os.chdir('..')
        self.logger("Saved new tree to %s"%out_fname, 1)
        if not debug:
            remove_dir(self.run_dir)


    def build_newick_fasttree(self, out_fname):
        from Bio import Phylo, AlignIO
        AlignIO.write(self.aln, 'temp.fasta', 'fasta')
        self.logger("Building tree with fasttree", 1)
        tree_cmd = ["fasttree"]
        if self.nuc: tree_cmd.append("-nt")

        tree_cmd.extend(["temp.fasta","1>",out_fname, "2>", "fasttree_stderr"])
        os.system(" ".join(tree_cmd))


    def build_newick_raxml(self, out_fname, nthreads=2, raxml_bin="raxml",
                           num_distinct_starting_trees=1, **kwargs):
        from Bio import Phylo, AlignIO
        import shutil
        self.logger("modified RAxML script - no branch length optimisation or time limit", 1)
        AlignIO.write(self.aln,"temp.phyx", "phylip-relaxed")
        if num_distinct_starting_trees == 1:
            cmd = raxml_bin + " -f d -T " + str(nthreads) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx"
        else:
            self.logger("RAxML running with {} starting trees (longer but better...)".format(num_distinct_starting_trees), 1)
            cmd = raxml_bin + " -f d -T " + str(nthreads) + " -N " + str(num_distinct_starting_trees) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx"

        try:
            with open("raxml.log", 'w') as fh:
                check_call(cmd, stdout=fh, stderr=STDOUT, shell=True)
                self.logger("RAXML COMPLETED.", 1)
        except CalledProcessError:
            self.logger("RAXML TREE FAILED - check {}/raxml.log".format(self.run_dir), 1)
            raise
        shutil.copy('RAxML_bestTree.tre', out_fname)


    def build_newick_iqtree(self, out_fname, nthreads=2, iqtree_bin="iqtree",
                            iqmodel="HKY",  **kwargs):
        from Bio import Phylo, AlignIO
        import shutil
        self.logger("modified RAxML script - no branch length optimisation or time limit", 1)
        aln_file = "temp.fasta"
        AlignIO.write(self.aln, aln_file, "fasta")
        with open(aln_file) as ifile:
            tmp_seqs = ifile.readlines()
        with open(aln_file, 'w') as ofile:
            for line in tmp_seqs:
                ofile.write(line.replace('/', '_X_X_'))

        if iqmodel:
            call = [iqtree_bin, "-nt", str(nthreads), "-s", aln_file, "-m", iqmodel, "-fast",
                ">", "iqtree.log"]
        else:
            call = [iqtree_bin, "-nt", str(nthreads), "-s", aln_file, ">", "iqtree.log"]

        os.system(" ".join(call))
        T = Phylo.read(aln_file+".treefile", 'newick')
        for n in T.get_terminals():
            n.name = n.name.replace('_X_X_','/')
        Phylo.write(T,out_fname, 'newick')


    def tt_from_file(self, infile, root='best', nodefile=None):
        self.is_timetree=False
        self.logger('Reading tree from file '+infile,2)
        dates  =   {seq.id:seq.attributes['num_date']
                    for seq in self.aln if 'date' in seq.attributes}
        self.tt = TreeTime(dates=dates, tree=str(infile), gtr='Jukes-Cantor',
                            aln = self.aln, verbose=self.verbose, fill_overhangs=True)
        if root:
            self.tt.reroot(root=root)
        self.tree = self.tt.tree

        for node in self.tree.find_clades():
            if node.is_terminal() and node.name in self.sequence_lookup:
                seq = self.sequence_lookup[node.name]
                node.attr = seq.attributes
                try:
                    node.attr['date'] = node.attr['date'].strftime('%Y-%m-%d')
                except:
                    pass
            else:
                node.attr = {}

        if nodefile is not None:
            self.logger('reading node properties from file: '+nodefile,2)
            with myopen(nodefile, 'r') as infile:
                node_props = pickle.load(infile)
            for n in self.tree.find_clades():
                if n.name in node_props:
                    for attr in node_props[n.name]:
                        n.__setattr__(attr, node_props[n.name][attr])
                else:
                    self.logger("No node properties found for "+n.name,2)


    def ancestral(self, **kwarks):
        self.tt.optimize_seq_and_branch_len(infer_gtr=True, **kwarks)
        self.dump_attr.append('sequence')
        for node in self.tree.find_clades():
            if not hasattr(node,'attr'):
                node.attr = {}

    ## TODO REMOVE KWARKS - MAKE EXPLICIT
    def timetree(self, Tc=0.01, infer_gtr=True, reroot='best', resolve_polytomies=True,
                 max_iter=2, confidence=False, use_marginal=False, **kwarks):
        self.logger('estimating time tree...',2)
        if confidence and use_marginal:
            marginal = 'assign'
        else:
            marginal = confidence
        self.tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal,
                    resolve_polytomies=resolve_polytomies, max_iter=max_iter, **kwarks)
        self.logger('estimating time tree...done',3)
        self.dump_attr.extend(['numdate','date','sequence'])
        to_numdate = self.tt.date2dist.to_numdate
        for node in self.tree.find_clades():
            if hasattr(node,'attr'):
                node.attr['num_date'] = node.numdate
            else:
                node.attr = {'num_date':node.numdate}
            if confidence:
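                # get_max_posterior_region(node, fraction=0.9) is expected to return the
                # bounds of the 90% highest-posterior date interval; sorting yields an
                # ascending [lower, upper] pair of numeric dates.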
                node.attr["num_date_confidence"] = sorted(self.tt.get_max_posterior_region(node, fraction=0.9))

        self.is_timetree=True


    def save_timetree(self, fprefix, ttopts, cfopts):
        Phylo.write(self.tt.tree, fprefix+"_timetree.new", "newick")
        n = {}
        attrs = ["branch_length", "mutation_length", "clock_length", "dist2root",
                 "name", "mutations", "attr", "cseq", "sequence", "numdate"]
        for node in self.tt.tree.find_clades():
            n[node.name] = {}
            for x in attrs:
                n[node.name][x] = getattr(node, x)
        with open(fprefix+"_timetree.pickle", 'wb') as fh:
            pickle.dump({
                "timetree_options": ttopts,
                "clock_filter_options": cfopts,
                "nodes": n,
                "original_seqs": list(self.sequence_lookup.keys()),
            }, fh, protocol=pickle.HIGHEST_PROTOCOL)

    def restore_timetree_node_info(self, nodes):
        for node in self.tt.tree.find_clades():
            info = nodes[node.name]
            # print("restoring node info for node ", node.name)
            for k, v in info.items():
                setattr(node, k, v)
        self.is_timetree=True


    def remove_outlier_clades(self, max_nodes=3, min_length=0.03):
        '''
        check whether one child clade of the root is small and the connecting branch
        is long. if so, move the root up and reset a few tree props
        Args:
            max_nodes   number of nodes beyond which the outliers are not removed
            min_length  minimal length of the branch connecting the outlier clade
                        to the rest of the tree to allow cutting.
        Returns:
            list of names of strains that have been removed
        '''
        R = self.tt.tree.root
        if len(R.clades)>2:
            return

        num_child_nodes = np.array([c.count_terminals() for c in R])
        putative_outlier = num_child_nodes.argmin()
        bl = np.sum([c.branch_length for c in R])
        if (bl>min_length and num_child_nodes[putative_outlier]<max_nodes):
            if num_child_nodes[putative_outlier]==1:
                print("removing \"{}\" which is an outlier clade".format(R.clades[putative_outlier].name))
            else:
                print("removing {} isolates which formed an outlier clade".format(num_child_nodes[putative_outlier]))
            self.tt.tree.root = R.clades[(1+putative_outlier)%2]
        self.tt.prepare_tree()
        return [c.name for c in R.clades[putative_outlier].get_terminals()]


    def geo_inference(self, attr, missing='?', root_state=None, report_confidence=False):
        '''
        infer a "mugration" model by pretending each region corresponds to a sequence
        state and repurposing the GTR inference and ancestral reconstruction
        '''
        from treetime import GTR
        # Determine alphabet
        places = set()
        for node in self.tree.find_clades():
            if hasattr(node, 'attr'):
                if attr in node.attr and node.attr[attr]!=missing:
                    places.add(node.attr[attr])
        if root_state is not None:
            places.add(root_state)

        # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45)
        places = sorted(places)
        nc = len(places)
        if nc>180:
            self.logger("geo_inference: can't have more than 180 places!",1)
            return
        elif nc==1:
            self.logger("geo_inference: only one place found -- setting every internal node to %s!"%places[0],1)
            for node in self.tree.find_clades():
                node.attr[attr] = places[0]
                node.__setattr__(attr+'_transitions',[])
            return
        elif nc==0:
            self.logger("geo_inference: list of places is empty!",1)
            return

        # store previously reconstructed sequences
        nuc_seqs = {}
        nuc_muts = {}
        nuc_seq_LH = None
        if hasattr(self.tt.tree,'sequence_LH'):
            nuc_seq_LH = self.tt.tree.sequence_LH
        for node in self.tree.find_clades():
            if hasattr(node, 'sequence'):
                nuc_seqs[node] = node.sequence
            if hasattr(node, 'mutations'):
                nuc_muts[node] = node.mutations
                node.__delattr__('mutations')


        alphabet = {chr(65+i):place for i,place in enumerate(places)}
        sequence_gtr = self.tt.gtr
        myGeoGTR = GTR.custom(pi = np.ones(nc, dtype=float)/nc, W=np.ones((nc,nc)),
                              alphabet = np.array(sorted(alphabet.keys())))
        missing_char = chr(65+nc)
        alphabet[missing_char]=missing
        myGeoGTR.profile_map[missing_char] = np.ones(nc)
        alphabet_rev = {v:k for k,v in alphabet.items()}
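        # Note: the missing state gets its own letter, chr(65+nc), i.e. the next letter
        # after the last real place, but a flat profile over all nc real states, so tips
        # with unknown location stay uninformative instead of being pinned to one place.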

        # set geo info to nodes as one letter sequence.
        self.tt.seq_len = 1
        for node in self.tree.get_terminals():
            if hasattr(node, 'attr'):
                if attr in node.attr:
                    node.sequence=np.array([alphabet_rev[node.attr[attr]]])
                else:
                    node.sequence=np.array([missing_char])
            else:
                node.sequence=np.array([missing_char])
        for node in self.tree.get_nonterminals():
            node.__delattr__('sequence')
        if root_state is not None:
            self.tree.root.split(n=1, branch_length=0.0)
            extra_clade = self.tree.root.clades[-1]
            extra_clade.name = "dummy_root_node"
            extra_clade.up = self.tree.root
            extra_clade.sequence = np.array([alphabet_rev[root_state]])
        self.tt.make_reduced_alignment()
        # set custom GTR model, run inference
        self.tt._gtr = myGeoGTR
        # import pdb; pdb.set_trace()
        tmp_use_mutation_length = self.tt.use_mutation_length
        self.tt.use_mutation_length=False
        self.tt.infer_ancestral_sequences(method='ml', infer_gtr=False,
            store_compressed=False, pc=5.0, marginal=True, normalized_rate=False)

        if root_state is not None:
            self.tree.prune(extra_clade)
        # restore the nucleotide sequence and mutations to maintain expected behavior
        self.tt.geogtr = self.tt.gtr
        self.tt.geogtr.alphabet_to_location = alphabet
        self.tt._gtr = sequence_gtr
        if hasattr(self.tt.tree,'sequence_LH'):
            self.tt.tree.geo_LH = self.tt.tree.sequence_LH
            self.tt.tree.sequence_LH = nuc_seq_LH
        for node in self.tree.find_clades():
            node.attr[attr] = alphabet[node.sequence[0]]
            if node in nuc_seqs:
                node.sequence = nuc_seqs[node]
            if node.up is not None:
                node.__setattr__(attr+'_transitions', node.mutations)
                if node in nuc_muts:
                    node.mutations = nuc_muts[node]
            # save marginal likelihoods if desired
            if report_confidence:
                node.attr[attr + "_entropy"] = sum([v * math.log(v+1E-20) for v in node.marginal_profile[0]]) * -1 / math.log(len(node.marginal_profile[0]))
                # javascript: vals.map((v) => v * Math.log(v + 1E-10)).reduce((a, b) => a + b, 0) * -1 / Math.log(vals.length);
                marginal = [(alphabet[self.tt.geogtr.alphabet[i]], node.marginal_profile[0][i]) for i in range(0, len(self.tt.geogtr.alphabet))]
                marginal.sort(key=lambda x: x[1], reverse=True) # sort on likelihoods
                marginal = [(a, b) for a, b in marginal if b > 0.01][:4] #only take stuff over 1% and the top 4 elements
                node.attr[attr + "_confidence"] = {a:b for a,b in marginal}
        self.tt.use_mutation_length=tmp_use_mutation_length

        # store saved attrs for save/restore functionality
        if not hasattr(self, "mugration_attrs"):
            self.mugration_attrs = []
        self.mugration_attrs.append(attr)
        if report_confidence:
            self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"])

    def restore_geo_inference(self, data, attr, confidence):
        if data == False:
            raise KeyError #yeah, not great
        for node in self.tree.find_clades():
            node.attr[attr] = data[node.name][attr]
            if confidence:
                node.attr[attr+"_confidence"] = data[node.name][attr+"_confidence"]
                node.attr[attr+"_entropy"] = data[node.name][attr+"_entropy"]
        if not hasattr(self, "mugration_attrs"):
            self.mugration_attrs = []
        self.mugration_attrs.append(attr)
        if confidence:
            self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"])



    def get_attr_list(self, get_attr):
        states = []
        for node in self.tree.find_clades():
            if get_attr in node.attr:
                states.append(node.attr[get_attr])
        return states

    def add_translations(self):
        '''
        translate the nucleotide sequence into the proteins specified
        in self.proteins. these are expected to be SeqFeatures
        '''
        from Bio import Seq

        # Sort proteins by start position of the corresponding SeqFeature entry.
        sorted_proteins = sorted(self.proteins.items(), key=lambda protein_pair: protein_pair[1].start)

        for node in self.tree.find_clades(order='preorder'):
            if not hasattr(node, "translations"):
                # Maintain genomic order of protein translations for easy
                # assembly by downstream functions.
                node.translations=OrderedDict()
                node.aa_mutations = {}

            for prot, feature in sorted_proteins:
                node.translations[prot] = Seq.translate(str(feature.extract(Seq.Seq("".join(node.sequence)))).replace('-', 'N'))

                if node.up is None:
                    node.aa_mutations[prot] = []
                else:
                    node.aa_mutations[prot] = [(a,pos,d) for pos, (a,d) in
                                               enumerate(zip(node.up.translations[prot],
                                                             node.translations[prot])) if a!=d]

        self.dump_attr.append('translations')


    def refine(self):
        '''
        add attributes for export, currently this is only muts and aa_muts
        '''
        self.tree.ladderize()
        for node in self.tree.find_clades():
            if node.up is not None:
                node.muts = ["".join(map(str, [a, pos+1, d])) for a,pos,d in node.mutations if '-' not in [a,d]]

                # Sort all deletions by position to enable identification of
                # deletions >1 bp below.
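                # Worked instance of the run grouping below: gap characters at 0-based
                # positions 9, 10 and 11 are reported as a single "deletion 9-12"
                # (pos-length to pos+1), while an isolated gap at position 5 with
                # ancestral base A is reported like a point mutation, "A6-".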
                deletions = sorted(
                    [(a,pos,d) for a,pos, d in node.mutations if '-' in [a,d]],
                    key=lambda mutation: mutation[1]
                )

                if len(deletions):
                    length = 0
                    for pi, (a,pos,d) in enumerate(deletions[:-1]):
                        if pos!=deletions[pi+1][1]-1:
                            if length==0:
                                node.muts.append(a+str(pos+1)+d)
                            elif d=='-':
                                node.muts.append("deletion %d-%d"%(pos-length, pos+1))
                            else:
                                node.muts.append("insertion %d-%d"%(pos-length, pos+1))
                            # reset the run length once a run has been reported so the
                            # stale count does not bleed into the next, unrelated run
                            length = 0
                        else:
                            length+=1
                    (a,pos,d) = deletions[-1]
                    if length==0:
                        node.muts.append(a+str(pos+1)+d)
                    elif d=='-':
                        node.muts.append("deletion %d-%d"%(pos-length, pos+1))
                    else:
                        node.muts.append("insertion %d-%d"%(pos-length, pos+1))


                node.aa_muts = {}
                if hasattr(node, 'translations'):
                    for prot in node.translations:
                        node.aa_muts[prot] = ["".join(map(str,[a,pos+1,d])) for a,pos,d in node.aa_mutations[prot]]
        for node in self.tree.find_clades(order="preorder"):
            if node.up is not None: #try:
                node.attr["div"] = node.up.attr["div"]+node.mutation_length
            else:
                node.attr["div"] = 0
        self.dump_attr.extend(['muts', 'aa_muts', 'aa_mutations', 'mutation_length', 'mutations'])


    def layout(self):
        """Add clade, xvalue, yvalue, mutation and trunk attributes to all nodes in tree"""
        clade = 0
        yvalue = self.tree.count_terminals()
        for node in self.tree.find_clades(order="preorder"):
            node.clade = clade
            clade += 1
            if node.up is not None: #try:
                node.xvalue = node.up.xvalue+node.mutation_length
                if self.is_timetree:
                    node.tvalue = node.numdate - self.tree.root.numdate
                else:
                    node.tvalue = 0
            else:
                node.xvalue = 0
                node.tvalue = 0
            if node.is_terminal():
                node.yvalue = yvalue
                yvalue -= 1
        for node in self.tree.get_nonterminals(order="postorder"):
            node.yvalue = np.mean([x.yvalue for x in node.clades])
        self.dump_attr.extend(['yvalue', 'xvalue', 'clade'])
        if self.is_timetree:
            self.dump_attr.extend(['tvalue'])


    def export(self, path = '', extra_attr = ['aa_muts', 'clade'], plain_export = 10, indent=None, write_seqs_json=True):
        '''
        export the tree data structure along with the sequence information as
        json files for display in web browsers.
        parameters:
            path    -- path (incl prefix) to which the output files are written.
                       filenames themselves are standardized to *tree.json and *sequences.json
            extra_attr -- attributes of tree nodes that are exported to json
            plain_export -- store sequences as plain strings instead of
                            differences to the root if the number of differences exceeds
                            len(seq)/plain_export
        '''
        from Bio import Seq
        timetree_fname = path+'_tree.json'
        sequence_fname = path+'_sequences.json'
        tree_json = tree_to_json(self.tree.root, extra_attr=extra_attr)
        write_json(tree_json, timetree_fname, indent=indent)

        # prepare a json with sequence information to export.
        # first step: add the sequence & translations of the root as string
        elems = {}
        elems['root'] = {}
        elems['root']['nuc'] = "".join(self.tree.root.sequence)
        for prot,seq in self.tree.root.translations.items():
            elems['root'][prot] = seq

        # add sequence for every node in tree. code as difference to root
        # or as full strings.
        for node in self.tree.find_clades():
            if hasattr(node, "clade"):
                elems[node.clade] = {}
                # loop over proteins and nucleotide sequences
                for prot, seq in [('nuc', "".join(node.sequence))]+list(node.translations.items()):
                    differences = {pos:state for pos, (state, ancstate) in
                                enumerate(zip(seq, elems['root'][prot]))
                                if state!=ancstate}
                    if plain_export*len(differences)<=len(seq):
                        elems[node.clade][prot] = differences
                    else:
                        elems[node.clade][prot] = seq
        if write_seqs_json:
            write_json(elems, sequence_fname, indent=indent)
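
A corresponding sketch for this newer variant (again hypothetical, with the same caveats as above): module-level imports such as TreeTime, Phylo, pickle, np, math, OrderedDict, check_call/STDOUT/CalledProcessError and the helpers make_dir, remove_dir, myopen, tree_to_json, write_json are assumed to exist, the chosen tree builder (raxml, fasttree or iqtree) is assumed to be on the PATH, and the alignment records are assumed to carry 'date', 'num_date' and 'region' attributes as in the sketch above.

T = tree(aln, proteins={}, verbose=2)
T.build_newick("tree_raw.newick", nthreads=2, method="iqtree",
               iqtree_options={"iqmodel": "GTR"})
T.tt_from_file("tree_raw.newick", root="best")
T.ancestral()
T.timetree(Tc=0.01, reroot="best", confidence=True, use_marginal=True)
T.geo_inference("region", missing="?", report_confidence=True)
T.refine()
T.layout()
T.export(path="output/flu", write_seqs_json=True)     # writes output/flu_tree.json etc.
T.save_timetree("output/flu", ttopts={}, cfopts={})   # pickled node info for later restore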
Example #3
class tree(object):
    """tree builds a phylgenetic tree from an alignment and exports it for web visualization"""
    def __init__(self, aln, proteins=None, verbose=2, logger=None, **kwarks):
        super(tree, self).__init__()
        self.aln = aln
        # self.nthreads = 2
        self.sequence_lookup = {seq.id:seq for seq in aln}
        self.nuc = kwarks['nuc'] if 'nuc' in kwarks else True
        self.dump_attr = [] # deprecated
        self.verbose = verbose
        if proteins!=None:
            self.proteins = proteins
        else:
            self.proteins={}
        if 'run_dir' not in kwarks:
            import random
            self.run_dir = '_'.join(['temp', time.strftime('%Y%m%d-%H%M%S',time.gmtime()), str(random.randint(0,1000000))])
        else:
            self.run_dir = kwarks['run_dir']
        if logger is None:
            def f(x,y):
                if y<self.verbose: print(x)
            self.logger = f
        else:
            self.logger=logger

    def getDateMin(self):
        return self.tree.root.date

    def getDateMax(self):
        dateMax = self.tree.root.date
        for node in self.tree.find_clades():
            if node.date > dateMax:
                dateMax = node.date
        return dateMax

    def dump(self, treefile, nodefile):
        from Bio import Phylo
        Phylo.write(self.tree, treefile, 'newick')
        node_props = {}
        for node in self.tree.find_clades():
            node_props[node.name] = {attr:node.__getattribute__(attr) for attr in self.dump_attr if hasattr(node, attr)}

        with myopen(nodefile, 'w') as nfile:
            from cPickle import dump
            dump(node_props, nfile)

    def check_newick(self, newick_file):
        try:
            tree = Phylo.parse(newick_file, 'newick').next()
            assert(set([x.name for x in tree.get_terminals()]) == set(self.sequence_lookup.keys()))
            return True
        except:
            return False

    def build_newick(self, newick_file, nthreads=2, method="raxml", raxml_options={},
                     iqtree_options={}, debug=False):
        make_dir(self.run_dir)
        os.chdir(self.run_dir)
        for seq in self.aln: seq.name=seq.id
        out_fname = os.path.join("..", newick_file)
        if method=="raxml":
            self.build_newick_raxml(out_fname, nthreads=nthreads, **raxml_options)
        elif method=="fasttree":
            self.build_newick_fasttree(out_fname)
        elif method=="iqtree":
            self.build_newick_iqtree(out_fname, **iqtree_options)
        os.chdir('..')
        self.logger("Saved new tree to %s"%out_fname, 1)
        if not debug:
            remove_dir(self.run_dir)


    def build_newick_fasttree(self, out_fname):
        from Bio import Phylo, AlignIO
        AlignIO.write(self.aln, 'temp.fasta', 'fasta')
        self.logger("Building tree with fasttree", 1)
        tree_cmd = ["fasttree"]
        if self.nuc: tree_cmd.append("-nt")

        tree_cmd.extend(["temp.fasta","1>",out_fname, "2>", "fasttree_stderr"])
        os.system(" ".join(tree_cmd))


    def build_newick_raxml(self, out_fname, nthreads=2, raxml_bin="raxml",
                           num_distinct_starting_trees=1, **kwargs):
        from Bio import Phylo, AlignIO
        import shutil
        self.logger("modified RAxML script - no branch length optimisation or time limit", 1)
        AlignIO.write(self.aln,"temp.phyx", "phylip-relaxed")
        if num_distinct_starting_trees == 1:
            cmd = raxml_bin + " -f d -T " + str(nthreads) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx"
        else:
            self.logger("RAxML running with {} starting trees (longer but better...)".format(num_distinct_starting_trees), 1)
            cmd = raxml_bin + " -f d -T " + str(nthreads) + " -N " + str(num_distinct_starting_trees) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx"

        try:
            with open("raxml.log", 'w') as fh:
                check_call(cmd, stdout=fh, stderr=STDOUT, shell=True)
                self.logger("RAXML COMPLETED.", 1)
        except CalledProcessError:
            self.logger("RAXML TREE FAILED - check {}/raxml.log".format(self.run_dir), 1)
            raise
        shutil.copy('RAxML_bestTree.tre', out_fname)


    def build_newick_iqtree(self, out_fname, nthreads=2, iqtree_bin="iqtree",
                            iqmodel="HKY",  **kwargs):
        from Bio import Phylo, AlignIO
        import shutil
        self.logger("modified RAxML script - no branch length optimisation or time limit", 1)
        aln_file = "temp.fasta"
        AlignIO.write(self.aln, aln_file, "fasta")
        with open(aln_file) as ifile:
            tmp_seqs = ifile.readlines()
        with open(aln_file, 'w') as ofile:
            for line in tmp_seqs:
                ofile.write(line.replace('/', '_X_X_'))

        if iqmodel:
            call = [iqtree_bin, "-nt", str(nthreads), "-s", aln_file, "-m", iqmodel, "-fast",
                ">", "iqtree.log"]
        else:
            call = [iqtree_bin, "-nt", str(nthreads), "-s", aln_file, ">", "iqtree.log"]

        os.system(" ".join(call))
        T = Phylo.read(aln_file+".treefile", 'newick')
        for n in T.get_terminals():
            n.name = n.name.replace('_X_X_','/')
        Phylo.write(T,out_fname, 'newick')


    def tt_from_file(self, infile, root='best', nodefile=None):
        self.is_timetree=False
        self.logger('Reading tree from file '+infile,2)
        dates  =   {seq.id:seq.attributes['num_date']
                    for seq in self.aln if 'date' in seq.attributes}
        self.tt = TreeTime(dates=dates, tree=str(infile), gtr='Jukes-Cantor',
                            aln = self.aln, verbose=self.verbose, fill_overhangs=True)
        if root:
            self.tt.reroot(root=root)
        self.tree = self.tt.tree

        for node in self.tree.find_clades():
            if node.is_terminal() and node.name in self.sequence_lookup:
                seq = self.sequence_lookup[node.name]
                node.attr = seq.attributes
                try:
                    node.attr['date'] = node.attr['date'].strftime('%Y-%m-%d')
                except:
                    pass
            else:
                node.attr = {}

        if nodefile is not None:
            self.logger('reading node properties from file: '+nodefile,2)
            with myopen(nodefile, 'r') as infile:
                from cPickle import load
                node_props = load(infile)
            for n in self.tree.find_clades():
                if n.name in node_props:
                    for attr in node_props[n.name]:
                        n.__setattr__(attr, node_props[n.name][attr])
                else:
                    self.logger("No node properties found for "+n.name,2)


    def ancestral(self, **kwarks):
        self.tt.optimize_seq_and_branch_len(infer_gtr=True, **kwarks)
        self.dump_attr.append('sequence')
        for node in self.tree.find_clades():
            if not hasattr(node,'attr'):
                node.attr = {}

    ## TODO REMOVE KWARKS - MAKE EXPLICIT
    def timetree(self, Tc=0.01, infer_gtr=True, reroot='best', resolve_polytomies=True,
                 max_iter=2, confidence=False, use_marginal=False, **kwarks):
        self.logger('estimating time tree...',2)
        if confidence and use_marginal:
            marginal = 'assign'
        else:
            marginal = confidence
        self.tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal,
                    resolve_polytomies=resolve_polytomies, max_iter=max_iter, **kwarks)
        self.logger('estimating time tree...done',3)
        self.dump_attr.extend(['numdate','date','sequence'])
        to_numdate = self.tt.date2dist.to_numdate
        for node in self.tree.find_clades():
            if hasattr(node,'attr'):
                node.attr['num_date'] = node.numdate
            else:
                node.attr = {'num_date':node.numdate}
            if confidence:
                node.attr["num_date_confidence"] = sorted(self.tt.get_max_posterior_region(node, fraction=0.9))

        self.is_timetree=True


    def save_timetree(self, fprefix, ttopts, cfopts):
        Phylo.write(self.tt.tree, fprefix+"_timetree.new", "newick")
        n = {}
        attrs = ["branch_length", "mutation_length", "clock_length", "dist2root",
                 "name", "mutations", "attr", "cseq", "sequence", "numdate"]
        for node in self.tt.tree.find_clades():
            n[node.name] = {}
            for x in attrs:
                n[node.name][x] = getattr(node, x)
        with open(fprefix+"_timetree.pickle", 'wb') as fh:
            pickle.dump({
                "timetree_options": ttopts,
                "clock_filter_options": cfopts,
                "nodes": n,
                "original_seqs": self.sequence_lookup.keys(),
            }, fh, protocol=pickle.HIGHEST_PROTOCOL)

    def restore_timetree_node_info(self, nodes):
        for node in self.tt.tree.find_clades():
            info = nodes[node.name]
            # print("restoring node info for node ", node.name)
            for k, v in info.iteritems():
                setattr(node, k, v)
        self.is_timetree=True


    def remove_outlier_clades(self, max_nodes=3, min_length=0.03):
        '''
        check whether one child clade of the root is small and the connecting branch
        is long. if so, move the root up and reset a few tree props
        Args:
            max_nodes   number of nodes beyond which the outliers are not removed
            min_length  minimal length of the branch connecting the outlier clade
                        to the rest of the tree to allow cutting.
        Returns:
            list of names of strains that have been removed
        '''
        R = self.tt.tree.root
        if len(R.clades)>2:
            return

        num_child_nodes = np.array([c.count_terminals() for c in R])
        putative_outlier = num_child_nodes.argmin()
        bl = np.sum([c.branch_length for c in R])
        if (bl>min_length and num_child_nodes[putative_outlier]<max_nodes):
            if num_child_nodes[putative_outlier]==1:
                print("removing \"{}\" which is an outlier clade".format(R.clades[putative_outlier].name))
            else:
                print("removing {} isolates which formed an outlier clade".format(num_child_nodes[putative_outlier]))
            self.tt.tree.root = R.clades[(1+putative_outlier)%2]
        self.tt.prepare_tree()
        return [c.name for c in R.clades[putative_outlier].get_terminals()]


    def geo_inference(self, attr, missing='?', root_state=None, report_confidence=False):
        '''
        infer a "mugration" model by pretending each region corresponds to a sequence
        state and repurposing the GTR inference and ancestral reconstruction
        '''
        from treetime import GTR
        # Determine alphabet
        places = set()
        for node in self.tree.find_clades():
            if hasattr(node, 'attr'):
                if attr in node.attr and node.attr[attr]!=missing:
                    places.add(node.attr[attr])
        if root_state is not None:
            places.add(root_state)

        # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45)
        places = sorted(places)
        nc = len(places)
        if nc>180:
            self.logger("geo_inference: can't have more than 180 places!",1)
            return
        elif nc==1:
            self.logger("geo_inference: only one place found -- setting every internal node to %s!"%places[0],1)
            for node in self.tree.find_clades():
                node.attr[attr] = places[0]
                node.__setattr__(attr+'_transitions',[])
            return
        elif nc==0:
            self.logger("geo_inference: list of places is empty!",1)
            return

        # store previously reconstructed sequences
        nuc_seqs = {}
        nuc_muts = {}
        nuc_seq_LH = None
        if hasattr(self.tt.tree,'sequence_LH'):
            nuc_seq_LH = self.tt.tree.sequence_LH
        for node in self.tree.find_clades():
            if hasattr(node, 'sequence'):
                nuc_seqs[node] = node.sequence
            if hasattr(node, 'mutations'):
                nuc_muts[node] = node.mutations
                node.__delattr__('mutations')


        alphabet = {chr(65+i):place for i,place in enumerate(places)}
        sequence_gtr = self.tt.gtr
        myGeoGTR = GTR.custom(pi = np.ones(nc, dtype=float)/nc, W=np.ones((nc,nc)),
                              alphabet = np.array(sorted(alphabet.keys())))
        missing_char = chr(65+nc)
        alphabet[missing_char]=missing
        myGeoGTR.profile_map[missing_char] = np.ones(nc)
        alphabet_rev = {v:k for k,v in alphabet.iteritems()}

        # set geo info to nodes as one letter sequence.
        self.tt.seq_len = 1
        for node in self.tree.get_terminals():
            if hasattr(node, 'attr'):
                if attr in node.attr:
                    node.sequence=np.array([alphabet_rev[node.attr[attr]]])
                else:
                    node.sequence=np.array([missing_char])
            else:
                node.sequence=np.array([missing_char])
        for node in self.tree.get_nonterminals():
            node.__delattr__('sequence')
        if root_state is not None:
            self.tree.root.split(n=1, branch_length=0.0)
            extra_clade = self.tree.root.clades[-1]
            extra_clade.name = "dummy_root_node"
            extra_clade.up = self.tree.root
            extra_clade.sequence = np.array([alphabet_rev[root_state]])
        self.tt.make_reduced_alignment()
        # set custom GTR model, run inference
        self.tt._gtr = myGeoGTR
        # import pdb; pdb.set_trace()
        tmp_use_mutation_length = self.tt.use_mutation_length
        self.tt.use_mutation_length=False
        self.tt.infer_ancestral_sequences(method='ml', infer_gtr=False,
            store_compressed=False, pc=5.0, marginal=True, normalized_rate=False)

        if root_state is not None:
            self.tree.prune(extra_clade)
        # restore the nucleotide sequence and mutations to maintain expected behavior
        self.tt.geogtr = self.tt.gtr
        self.tt.geogtr.alphabet_to_location = alphabet
        self.tt._gtr = sequence_gtr
        if hasattr(self.tt.tree,'sequence_LH'):
            self.tt.tree.geo_LH = self.tt.tree.sequence_LH
            self.tt.tree.sequence_LH = nuc_seq_LH
        for node in self.tree.find_clades():
            node.attr[attr] = alphabet[node.sequence[0]]
            if node in nuc_seqs:
                node.sequence = nuc_seqs[node]
            if node.up is not None:
                node.__setattr__(attr+'_transitions', node.mutations)
                if node in nuc_muts:
                    node.mutations = nuc_muts[node]
            # save marginal likelihoods if desired
            if report_confidence:
                node.attr[attr + "_entropy"] = sum([v * math.log(v+1E-20) for v in node.marginal_profile[0]]) * -1 / math.log(len(node.marginal_profile[0]))
                # javascript: vals.map((v) => v * Math.log(v + 1E-10)).reduce((a, b) => a + b, 0) * -1 / Math.log(vals.length);
                marginal = [(alphabet[self.tt.geogtr.alphabet[i]], node.marginal_profile[0][i]) for i in range(0, len(self.tt.geogtr.alphabet))]
                marginal.sort(key=lambda x: x[1], reverse=True) # sort on likelihoods
                marginal = [(a, b) for a, b in marginal if b > 0.01][:4] #only take stuff over 1% and the top 4 elements
                node.attr[attr + "_confidence"] = {a:b for a,b in marginal}
        self.tt.use_mutation_length=tmp_use_mutation_length

        # store saved attrs for save/restore functionality
        if not hasattr(self, "mugration_attrs"):
            self.mugration_attrs = []
        self.mugration_attrs.append(attr)
        if report_confidence:
            self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"])

    def restore_geo_inference(self, data, attr, confidence):
        if data == False:
            raise KeyError #yeah, not great
        for node in self.tree.find_clades():
            node.attr[attr] = data[node.name][attr]
            if confidence:
                node.attr[attr+"_confidence"] = data[node.name][attr+"_confidence"]
                node.attr[attr+"_entropy"] = data[node.name][attr+"_entropy"]
        if not hasattr(self, "mugration_attrs"):
            self.mugration_attrs = []
        self.mugration_attrs.append(attr)
        if confidence:
            self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"])



    def get_attr_list(self, get_attr):
        states = []
        for node in self.tree.find_clades():
            if get_attr in node.attr:
                states.append(node.attr[get_attr])
        return states
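    # hypothetical usage: collections.Counter(T.get_attr_list('country')) tallies how
    # many nodes of a tree instance T carry each country assignment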

    def add_translations(self):
        '''
        translate the nucleotide sequence into the proteins specified
        in self.proteins. these are expected to be SeqFeatures
        '''
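        # self.proteins is expected to map protein names to feature objects that
        # expose .start (used for sorting below) and .extract() (used to slice out the
        # coding region), e.g. Biopython FeatureLocation objects; hypothetical example:
        #   {'HA1': FeatureLocation(48, 1035)}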
        from Bio import Seq

        # Sort proteins by start position of the corresponding SeqFeature entry.
        sorted_proteins = sorted(self.proteins.items(), key=lambda protein_pair: protein_pair[1].start)

        for node in self.tree.find_clades(order='preorder'):
            if not hasattr(node, "translations"):
                # Maintain genomic order of protein translations for easy
                # assembly by downstream functions.
                node.translations=OrderedDict()
                node.aa_mutations = {}

            for prot, feature in sorted_proteins:
                # extract the coding region, replace alignment gaps with 'N', and translate
                node.translations[prot] = Seq.translate(
                    str(feature.extract(Seq.Seq("".join(node.sequence)))).replace('-', 'N'))

                if node.up is None:
                    node.aa_mutations[prot] = []
                else:
                    node.aa_mutations[prot] = [(a,pos,d) for pos, (a,d) in
                                               enumerate(zip(node.up.translations[prot],
                                                             node.translations[prot])) if a!=d]
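                # each aa_mutations entry is (parental aa, 0-based codon index, derived aa),
                # e.g. ('K', 139, 'R') -- the example values here are illustrative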

        self.dump_attr.append('translations')


    def refine(self):
        '''
        add attributes for export: nucleotide muts, per-protein aa_muts, and the
        cumulative divergence 'div' of every node
        '''
        self.tree.ladderize()
        for node in self.tree.find_clades():
            if node.up is not None:
                node.muts = ["".join(map(str, [a, pos+1, d])) for a,pos,d in node.mutations if '-' not in [a,d]]

                # Sort all deletions by position to enable identification of
                # deletions >1 bp below.
                deletions = sorted(
                    [(a,pos,d) for a,pos, d in node.mutations if '-' in [a,d]],
                    key=lambda mutation: mutation[1]
                )

                if len(deletions):
                    # merge runs of consecutive indel positions into a single
                    # "deletion"/"insertion" range; isolated indels are reported
                    # like point mutations
                    length = 0
                    for pi, (a,pos,d) in enumerate(deletions[:-1]):
                        if pos!=deletions[pi+1][1]-1:
                            if length==0:
                                node.muts.append(a+str(pos+1)+d)
                            elif d=='-':
                                node.muts.append("deletion %d-%d"%(pos-length, pos+1))
                            else:
                                node.muts.append("insertion %d-%d"%(pos-length, pos+1))
                            length = 0  # reset so the next, separate run starts fresh
                        else:
                            length+=1
                    # the last indel closes whatever run is still open
                    (a,pos,d) = deletions[-1]
                    if length==0:
                        node.muts.append(a+str(pos+1)+d)
                    elif d=='-':
                        node.muts.append("deletion %d-%d"%(pos-length, pos+1))
                    else:
                        node.muts.append("insertion %d-%d"%(pos-length, pos+1))


                node.aa_muts = {}
                if hasattr(node, 'translations'):
                    for prot in node.translations:
                        node.aa_muts[prot] = ["".join(map(str,[a,pos+1,d])) for a,pos,d in node.aa_mutations[prot]]
        for node in self.tree.find_clades(order="preorder"):
            if node.up is not None:
                node.attr["div"] = node.up.attr["div"]+node.mutation_length
            else:
                node.attr["div"] = 0
        self.dump_attr.extend(['muts', 'aa_muts', 'aa_mutations', 'mutation_length', 'mutations'])


    def layout(self):
        """Add clade, xvalue, yvalue, mutation and trunk attributes to all nodes in tree"""
        clade = 0
        yvalue = self.tree.count_terminals()
        for node in self.tree.find_clades(order="preorder"):
            node.clade = clade
            clade += 1
            if node.up is not None:
                node.xvalue = node.up.xvalue+node.mutation_length
                if self.is_timetree:
                    node.tvalue = node.numdate - self.tree.root.numdate
                else:
                    node.tvalue = 0
            else:
                node.xvalue = 0
                node.tvalue = 0
            if node.is_terminal():
                node.yvalue = yvalue
                yvalue -= 1
        for node in self.tree.get_nonterminals(order="postorder"):
            node.yvalue = np.mean([x.yvalue for x in node.clades])
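        # tips receive descending integer yvalues in the order they are visited
        # (preorder); internal nodes are placed at the mean yvalue of their children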
        self.dump_attr.extend(['yvalue', 'xvalue', 'clade'])
        if self.is_timetree:
            self.dump_attr.extend(['tvalue'])


    def export(self, path = '', extra_attr = ['aa_muts', 'clade'], plain_export = 10, indent=None, write_seqs_json=True):
        '''
        export the tree data structure along with the sequence information as
        json files for display in web browsers.
        parameters:
            path    -- path (incl. prefix) to which the output files are written.
                       filenames themselves are standardized to *tree.json and *sequences.json
            extra_attr -- attributes of tree nodes that are exported to json
            plain_export -- store sequences as plain strings instead of
                            differences to the root if the number of differences
                            exceeds len(seq)/plain_export
            indent  -- indentation passed on to the json writer
            write_seqs_json -- if False, skip writing the *sequences.json file
        '''
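        # minimal usage sketch for an instance T of this class (the path prefix is
        # hypothetical); add_translations() and layout() must have run first, since
        # node.translations and node.clade are used below:
        #   T.export(path='auspice/flu')
        # which writes auspice/flu_tree.json and auspice/flu_sequences.json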
        from Bio import Seq
        from itertools import izip
        timetree_fname = path+'_tree.json'
        sequence_fname = path+'_sequences.json'
        tree_json = tree_to_json(self.tree.root, extra_attr=extra_attr)
        write_json(tree_json, timetree_fname, indent=indent)

        # prepare a json with sequence information to export.
        # first step: add the sequence & translations of the root as string
        elems = {}
        elems['root'] = {}
        elems['root']['nuc'] = "".join(self.tree.root.sequence)
        for prot,seq in self.tree.root.translations.iteritems():
            elems['root'][prot] = seq

        # add sequences for every node in the tree, encoded either as differences
        # to the root or as full strings.
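        # e.g. (illustrative) a node two substitutions away from the root is stored as
        # {"nuc": {145: "T", 1021: "A"}, ...}; a node with many differences falls back
        # to the full sequence string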
        for node in self.tree.find_clades():
            if hasattr(node, "clade"):
                elems[node.clade] = {}
                # loop over proteins and nucleotide sequences
                for prot, seq in [('nuc', "".join(node.sequence))]+node.translations.items():
                    differences = {pos:state for pos, (state, ancstate) in
                                enumerate(izip(seq, elems['root'][prot]))
                                if state!=ancstate}
                    if plain_export*len(differences)<=len(seq):
                        elems[node.clade][prot] = differences
                    else:
                        elems[node.clade][prot] = seq
        if write_seqs_json:
            write_json(elems, sequence_fname, indent=indent)