def timetree(tree=None, aln=None, seq_meta=None, keeproot=False, confidence=False, resolve_polytomies=True, max_iter=2, infer_gtr=True, Tc=0.01, reroot='best', use_marginal=False, **kwarks): from treetime import TreeTime dates = {} for name, data in seq_meta.items(): num_date = parse_date(data["date"], date_fmt) if num_date is not None: dates[name] = num_date tt = TreeTime(tree=tree, aln=aln, dates=dates, gtr='JC69') if confidence and use_marginal: # estimate confidence intervals via marginal ML and assign marginal ML times to nodes marginal = 'assign' else: marginal = confidence tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal, resolve_polytomies=resolve_polytomies, max_iter=max_iter, **kwarks) for n in T.find_clades(): n.num_date = n.numdate # treetime convention is different from augur... # get 90% max posterior region) if confidence: n.num_date_confidence = list(tt.get_max_posterior_region(n, 0.9)) return tt
def timetree(tree=None, aln=None, ref=None, seq_meta=None, keeproot=False, confidence=False, resolve_polytomies=True, max_iter=2, dateLimits=None, infer_gtr=True, Tc=0.01, reroot='best', use_marginal=False, **kwarks): from treetime import TreeTime dL_int = None if dateLimits: dL_int = [int(x) for x in dateLimits] dL_int.sort() dates = {} for name, data in seq_meta.items(): num_date = parse_date(data["date"], date_fmt, dL_int) if num_date is not None: dates[name] = num_date #send ref, if is None, does no harm tt = TreeTime(tree=tree, aln=aln, ref=ref, dates=dates, gtr='JC69') if confidence and use_marginal: # estimate confidence intervals via marginal ML and assign marginal ML times to nodes marginal = 'assign' else: marginal = confidence #Length of VCF files means GTR model with gaps causes overestimation of mutation TO gaps #so gaps appear in internal nodes when no gaps at tips! To prevent.... pi = None if ref != None: #if VCF, fix pi pi = np.array([0.1618, 0.3188, 0.3176, 0.1618, 0.04]) #from real runs (Walker 2018) tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal, resolve_polytomies=resolve_polytomies, max_iter=max_iter, fixed_pi=pi, **kwarks) for n in T.find_clades(): n.num_date = n.numdate # treetime convention is different from augur... # get 90% max posterior region) if confidence: n.num_date_confidence = list(tt.get_max_posterior_region(n, 0.9)) return tt
def refine(tree=None, aln=None, ref=None, dates=None, branch_length_inference='auto', confidence=False, resolve_polytomies=True, max_iter=2, precision='auto', infer_gtr=True, Tc=0.01, reroot=None, use_marginal=False, fixed_pi=None, clock_rate=None, clock_std=None, clock_filter_iqd=None, verbosity=1, covariance=True, **kwarks): from treetime import TreeTime try: #Tc could be a number or 'opt' or 'skyline'. TreeTime expects a float or int if a number. Tc = float(Tc) except ValueError: True #let it remain a string if (ref is not None) and (fixed_pi is None): #if VCF, fix pi #Otherwise mutation TO gaps is overestimated b/c of seq length fixed_pi = [ ref.count(base) / len(ref) for base in ['A', 'C', 'G', 'T', '-'] ] if fixed_pi[-1] == 0: fixed_pi[-1] = 0.05 fixed_pi = [v - 0.01 for v in fixed_pi] if ref is not None: # VCF -> adjust branch length #set branch length mode explicitly if auto, as informative-site only #trees can have big branch lengths, making this set incorrectly in TreeTime if branch_length_inference == 'auto': branch_length_inference = 'joint' #send ref, if is None, does no harm tt = TreeTime(tree=tree, aln=aln, ref=ref, dates=dates, verbose=verbosity, gtr='JC69', precision=precision) # conditionally run clock-filter and remove bad tips if clock_filter_iqd: # treetime clock filter will mark, but not remove bad tips tt.clock_filter(reroot=reroot, n_iqd=clock_filter_iqd, plot=False) #use whatever was specified # remove them explicitly leaves = [x for x in tt.tree.get_terminals()] for n in leaves: if n.bad_branch: tt.tree.prune(n) print('pruning leaf ', n.name) # fix treetime set-up for new tree topology tt.prepare_tree() if confidence and use_marginal: # estimate confidence intervals via marginal ML and assign # marginal ML times to nodes marginal = 'assign' else: marginal = confidence # uncertainty of the the clock rate is relevant if confidence intervals are estimated if confidence and clock_std: vary_rate = clock_std # if standard devivation of clock is specified, use that elif (clock_rate is None) and confidence and covariance: vary_rate = True # if run in covariance mode, standard deviation can be estimated else: vary_rate = False # otherwise, rate uncertainty will be ignored tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal, branch_length_mode=branch_length_inference, resolve_polytomies=resolve_polytomies, max_iter=max_iter, fixed_pi=fixed_pi, fixed_clock_rate=clock_rate, vary_rate=vary_rate, use_covariation=covariance, **kwarks) if confidence: for n in tt.tree.find_clades(): n.num_date_confidence = list(tt.get_max_posterior_region(n, 0.9)) print( "\nInferred a time resolved phylogeny using TreeTime:" "\n\tSagulenko et al. TreeTime: Maximum-likelihood phylodynamic analysis" "\n\tVirus Evolution, vol 4, https://academic.oup.com/ve/article/4/1/vex042/4794731\n" ) return tt
def timetree(tree=None, aln=None, ref=None, dates=None, branch_length_mode='auto', confidence=False, resolve_polytomies=True, max_iter=2, infer_gtr=True, Tc=0.01, reroot=None, use_marginal=False, fixed_pi=None, clock_rate=None, n_iqd=None, verbosity=1, **kwarks): from treetime import TreeTime try: #Tc could be a number or 'opt' or 'skyline'. TreeTime expects a float or int if a number. Tc = float(Tc) except ValueError: True #let it remain a string if ref != None: #if VCF, fix pi #Otherwise mutation TO gaps is overestimated b/c of seq length fixed_pi = [ ref.count(base) / len(ref) for base in ['A', 'C', 'G', 'T', '-'] ] if fixed_pi[-1] == 0: fixed_pi[-1] = 0.05 fixed_pi = [v - 0.01 for v in fixed_pi] #set this explicitly if auto, as informative-site only trees can have big branch lengths, #making this set incorrectly in TreeTime if branch_length_mode == 'auto': branch_length_mode = 'joint' #send ref, if is None, does no harm tt = TreeTime(tree=tree, aln=aln, ref=ref, dates=dates, verbose=verbosity, gtr='JC69') if confidence and use_marginal: # estimate confidence intervals via marginal ML and assign marginal ML times to nodes marginal = 'assign' else: marginal = confidence tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal, branch_length_mode=branch_length_mode, resolve_polytomies=resolve_polytomies, max_iter=max_iter, fixed_pi=fixed_pi, fixed_clock_rate=clock_rate, n_iqd=n_iqd, **kwarks) if confidence: for n in tt.tree.find_clades(): n.num_date_confidence = list(tt.get_max_posterior_region(n, 0.9)) print( "\nInferred a time resolved phylogeny using TreeTime:" "\n\tSagulenko et al. TreeTime: Maximum-likelihood phylodynamic analysis" "\n\tVirus Evolution, vol 4, https://academic.oup.com/ve/article/4/1/vex042/4794731\n" ) return tt
def refine(tree=None, aln=None, ref=None, dates=None, branch_length_inference='auto', confidence=False, resolve_polytomies=True, max_iter=2, infer_gtr=True, Tc=0.01, reroot=None, use_marginal=False, fixed_pi=None, clock_rate=None, clock_std=None, clock_filter_iqd=None, verbosity=1, **kwarks): from treetime import TreeTime try: #Tc could be a number or 'opt' or 'skyline'. TreeTime expects a float or int if a number. Tc = float(Tc) except ValueError: True #let it remain a string if (ref is not None) and (fixed_pi is None): #if VCF, fix pi #Otherwise mutation TO gaps is overestimated b/c of seq length fixed_pi = [ref.count(base)/len(ref) for base in ['A','C','G','T','-']] if fixed_pi[-1] == 0: fixed_pi[-1] = 0.05 fixed_pi = [v-0.01 for v in fixed_pi] if ref is not None: # VCF -> adjust branch length #set branch length mode explicitly if auto, as informative-site only #trees can have big branch lengths, making this set incorrectly in TreeTime if branch_length_inference == 'auto': branch_length_inference = 'joint' #send ref, if is None, does no harm tt = TreeTime(tree=tree, aln=aln, ref=ref, dates=dates, verbose=verbosity, gtr='JC69') # conditionally run clock-filter and remove bad tips if clock_filter_iqd: # treetime clock filter will mark, but not remove bad tips tt.clock_filter(reroot='best', n_iqd=clock_filter_iqd, plot=False) # remove them explicitly leaves = [x for x in tt.tree.get_terminals()] for n in leaves: if n.bad_branch: tt.tree.prune(n) print('pruning leaf ', n.name) # fix treetime set-up for new tree topology tt.prepare_tree() if confidence and use_marginal: # estimate confidence intervals via marginal ML and assign # marginal ML times to nodes marginal = 'assign' else: marginal = confidence vary_rate = False if clock_rate and clock_std: vary_rate = clock_std else: vary_rate = True tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal, branch_length_mode=branch_length_inference, resolve_polytomies=resolve_polytomies, max_iter=max_iter, fixed_pi=fixed_pi, fixed_clock_rate=clock_rate, vary_rate=vary_rate, **kwarks) if confidence: for n in tt.tree.find_clades(): n.num_date_confidence = list(tt.get_max_posterior_region(n, 0.9)) print("\nInferred a time resolved phylogeny using TreeTime:" "\n\tSagulenko et al. TreeTime: Maximum-likelihood phylodynamic analysis" "\n\tVirus Evolution, vol 4, https://academic.oup.com/ve/article/4/1/vex042/4794731\n") return tt
class tree(object): """tree builds a phylgenetic tree from an alignment and exports it for web visualization""" def __init__(self, aln, proteins=None, verbose=2, logger=None, **kwarks): super(tree, self).__init__() self.aln = aln # self.nthreads = 2 self.sequence_lookup = {seq.id:seq for seq in aln} self.nuc = kwarks['nuc'] if 'nuc' in kwarks else True self.dump_attr = [] # depreciated self.verbose = verbose if proteins!=None: self.proteins = proteins else: self.proteins={} if 'run_dir' not in kwarks: import random self.run_dir = '_'.join(['temp', time.strftime('%Y%m%d-%H%M%S',time.gmtime()), str(random.randint(0,1000000))]) else: self.run_dir = kwarks['run_dir'] if logger is None: def f(x,y): if y<self.verbose: print(x) self.logger = f else: self.logger=logger def getDateMin(self): return self.tree.root.date def getDateMax(self): dateMax = self.tree.root.date for node in self.tree.find_clades(): if node.date > dateMax: dateMax = node.date return dateMax def dump(self, treefile, nodefile): from Bio import Phylo Phylo.write(self.tree, treefile, 'newick') node_props = {} for node in self.tree.find_clades(): node_props[node.name] = {attr:node.__getattribute__(attr) for attr in self.dump_attr if hasattr(node, attr)} with myopen(nodefile, 'w') as nfile: pickle.dump(node_props, nfile) def check_newick(self, newick_file): try: tree = Phylo.parse(newick_file, 'newick').next() assert(set([x.name for x in tree.get_terminals()]) == set(self.sequence_lookup.keys())) return True except: return False def build_newick(self, newick_file, nthreads=2, method="raxml", raxml_options={}, iqtree_options={}, debug=False): make_dir(self.run_dir) os.chdir(self.run_dir) for seq in self.aln: seq.name=seq.id out_fname = os.path.join("..", newick_file) if method=="raxml": self.build_newick_raxml(out_fname, nthreads=nthreads, **raxml_options) elif method=="fasttree": self.build_newick_fasttree(out_fname) elif method=="iqtree": self.build_newick_iqtree(out_fname, **iqtree_options) os.chdir('..') self.logger("Saved new tree to %s"%out_fname, 1) if not debug: remove_dir(self.run_dir) def build_newick_fasttree(self, out_fname): from Bio import Phylo, AlignIO AlignIO.write(self.aln, 'temp.fasta', 'fasta') self.logger("Building tree with fasttree", 1) tree_cmd = ["fasttree"] if self.nuc: tree_cmd.append("-nt") tree_cmd.extend(["temp.fasta","1>",out_fname, "2>", "fasttree_stderr"]) os.system(" ".join(tree_cmd)) def build_newick_raxml(self, out_fname, nthreads=2, raxml_bin="raxml", num_distinct_starting_trees=1, **kwargs): from Bio import Phylo, AlignIO import shutil self.logger("modified RAxML script - no branch length optimisation or time limit", 1) AlignIO.write(self.aln,"temp.phyx", "phylip-relaxed") if num_distinct_starting_trees == 1: cmd = raxml_bin + " -f d -T " + str(nthreads) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx" else: self.logger("RAxML running with {} starting trees (longer but better...)".format(num_distinct_starting_trees), 1) cmd = raxml_bin + " -f d -T " + str(nthreads) + " -N " + str(num_distinct_starting_trees) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx" try: with open("raxml.log", 'w') as fh: check_call(cmd, stdout=fh, stderr=STDOUT, shell=True) self.logger("RAXML COMPLETED.", 1) except CalledProcessError: self.logger("RAXML TREE FAILED - check {}/raxml.log".format(self.run_dir), 1) raise shutil.copy('RAxML_bestTree.tre', out_fname) def build_newick_iqtree(self, out_fname, nthreads=2, iqtree_bin="iqtree", iqmodel="HKY", **kwargs): from Bio import Phylo, AlignIO import shutil self.logger("modified RAxML script - no branch length optimisation or time limit", 1) aln_file = "temp.fasta" AlignIO.write(self.aln, aln_file, "fasta") with open(aln_file) as ifile: tmp_seqs = ifile.readlines() with open(aln_file, 'w') as ofile: for line in tmp_seqs: ofile.write(line.replace('/', '_X_X_')) if iqmodel: call = ["iqtree", "-nt", str(nthreads), "-s", aln_file, "-m", iqmodel, "-fast", ">", "iqtree.log"] else: call = ["iqtree", "-nt", str(nthreads), "-s", aln_file, ">", "iqtree.log"] os.system(" ".join(call)) T = Phylo.read(aln_file+".treefile", 'newick') for n in T.get_terminals(): n.name = n.name.replace('_X_X_','/') Phylo.write(T,out_fname, 'newick') def tt_from_file(self, infile, root='best', nodefile=None): self.is_timetree=False self.logger('Reading tree from file '+infile,2) dates = {seq.id:seq.attributes['num_date'] for seq in self.aln if 'date' in seq.attributes} self.tt = TreeTime(dates=dates, tree=str(infile), gtr='Jukes-Cantor', aln = self.aln, verbose=self.verbose, fill_overhangs=True) if root: self.tt.reroot(root=root) self.tree = self.tt.tree for node in self.tree.find_clades(): if node.is_terminal() and node.name in self.sequence_lookup: seq = self.sequence_lookup[node.name] node.attr = seq.attributes try: node.attr['date'] = node.attr['date'].strftime('%Y-%m-%d') except: pass else: node.attr = {} if nodefile is not None: self.logger('reading node properties from file: '+nodefile,2) with myopen(nodefile, 'r') as infile: node_props = pickle.load(infile) for n in self.tree.find_clades(): if n.name in node_props: for attr in node_props[n.name]: n.__setattr__(attr, node_props[n.name][attr]) else: self.logger("No node properties found for "+n.name,2) def ancestral(self, **kwarks): self.tt.optimize_seq_and_branch_len(infer_gtr=True, **kwarks) self.dump_attr.append('sequence') for node in self.tree.find_clades(): if not hasattr(node,'attr'): node.attr = {} ## TODO REMOVE KWARKS - MAKE EXPLICIT def timetree(self, Tc=0.01, infer_gtr=True, reroot='best', resolve_polytomies=True, max_iter=2, confidence=False, use_marginal=False, **kwarks): self.logger('estimating time tree...',2) if confidence and use_marginal: marginal = 'assign' else: marginal = confidence self.tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal, resolve_polytomies=resolve_polytomies, max_iter=max_iter, **kwarks) self.logger('estimating time tree...done',3) self.dump_attr.extend(['numdate','date','sequence']) to_numdate = self.tt.date2dist.to_numdate for node in self.tree.find_clades(): if hasattr(node,'attr'): node.attr['num_date'] = node.numdate else: node.attr = {'num_date':node.numdate} if confidence: node.attr["num_date_confidence"] = sorted(self.tt.get_max_posterior_region(node, fraction=0.9)) self.is_timetree=True def save_timetree(self, fprefix, ttopts, cfopts): Phylo.write(self.tt.tree, fprefix+"_timetree.new", "newick") n = {} attrs = ["branch_length", "mutation_length", "clock_length", "dist2root", "name", "mutations", "attr", "cseq", "sequence", "numdate"] for node in self.tt.tree.find_clades(): n[node.name] = {} for x in attrs: n[node.name][x] = getattr(node, x) with open(fprefix+"_timetree.pickle", 'wb') as fh: pickle.dump({ "timetree_options": ttopts, "clock_filter_options": cfopts, "nodes": n, "original_seqs": list(self.sequence_lookup.keys()), }, fh, protocol=pickle.HIGHEST_PROTOCOL) def restore_timetree_node_info(self, nodes): for node in self.tt.tree.find_clades(): info = nodes[node.name] # print("restoring node info for node ", node.name) for k, v in info.items(): setattr(node, k, v) self.is_timetree=True def remove_outlier_clades(self, max_nodes=3, min_length=0.03): ''' check whether one child clade of the root is small and the connecting branch is long. if so, move the root up and reset a few tree props Args: max_nodes number of nodes beyond which the outliers are note removed min_length minimal length of the branch connecting the outlier clade to the rest of the tree to allow cutting. Returns: list of names of strains that have been removed ''' R = self.tt.tree.root if len(R.clades)>2: return num_child_nodes = np.array([c.count_terminals() for c in R]) putative_outlier = num_child_nodes.argmin() bl = np.sum([c.branch_length for c in R]) if (bl>min_length and num_child_nodes[putative_outlier]<max_nodes): if num_child_nodes[putative_outlier]==1: print("removing \"{}\" which is an outlier clade".format(R.clades[putative_outlier].name)) else: print("removing {} isolates which formed an outlier clade".format(num_child_nodes[putative_outlier])) self.tt.tree.root = R.clades[(1+putative_outlier)%2] self.tt.prepare_tree() return [c.name for c in R.clades[putative_outlier].get_terminals()] def geo_inference(self, attr, missing='?', root_state=None, report_confidence=False): ''' infer a "mugration" model by pretending each region corresponds to a sequence state and repurposing the GTR inference and ancestral reconstruction ''' from treetime import GTR # Determine alphabet places = set() for node in self.tree.find_clades(): if hasattr(node, 'attr'): if attr in node.attr and attr!=missing: places.add(node.attr[attr]) if root_state is not None: places.add(root_state) # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45) places = sorted(places) nc = len(places) if nc>180: self.logger("geo_inference: can't have more than 180 places!",1) return elif nc==1: self.logger("geo_inference: only one place found -- setting every internal node to %s!"%places[0],1) for node in self.tree.find_clades(): node.attr[attr] = places[0] node.__setattr__(attr+'_transitions',[]) return elif nc==0: self.logger("geo_inference: list of places is empty!",1) return # store previously reconstructed sequences nuc_seqs = {} nuc_muts = {} nuc_seq_LH = None if hasattr(self.tt.tree,'sequence_LH'): nuc_seq_LH = self.tt.tree.sequence_LH for node in self.tree.find_clades(): if hasattr(node, 'sequence'): nuc_seqs[node] = node.sequence if hasattr(node, 'mutations'): nuc_muts[node] = node.mutations node.__delattr__('mutations') alphabet = {chr(65+i):place for i,place in enumerate(places)} sequence_gtr = self.tt.gtr myGeoGTR = GTR.custom(pi = np.ones(nc, dtype=float)/nc, W=np.ones((nc,nc)), alphabet = np.array(sorted(alphabet.keys()))) missing_char = chr(65+nc) alphabet[missing_char]=missing myGeoGTR.profile_map[missing_char] = np.ones(nc) alphabet_rev = {v:k for k,v in alphabet.items()} # set geo info to nodes as one letter sequence. self.tt.seq_len = 1 for node in self.tree.get_terminals(): if hasattr(node, 'attr'): if attr in node.attr: node.sequence=np.array([alphabet_rev[node.attr[attr]]]) else: node.sequence=np.array([missing_char]) else: node.sequence=np.array([missing_char]) for node in self.tree.get_nonterminals(): node.__delattr__('sequence') if root_state is not None: self.tree.root.split(n=1, branch_length=0.0) extra_clade = self.tree.root.clades[-1] extra_clade.name = "dummy_root_node" extra_clade.up = self.tree.root extra_clade.sequence = np.array([alphabet_rev[root_state]]) self.tt.make_reduced_alignment() # set custom GTR model, run inference self.tt._gtr = myGeoGTR # import pdb; pdb.set_trace() tmp_use_mutation_length = self.tt.use_mutation_length self.tt.use_mutation_length=False self.tt.infer_ancestral_sequences(method='ml', infer_gtr=False, store_compressed=False, pc=5.0, marginal=True, normalized_rate=False) if root_state is not None: self.tree.prune(extra_clade) # restore the nucleotide sequence and mutations to maintain expected behavior self.tt.geogtr = self.tt.gtr self.tt.geogtr.alphabet_to_location = alphabet self.tt._gtr = sequence_gtr if hasattr(self.tt.tree,'sequence_LH'): self.tt.tree.geo_LH = self.tt.tree.sequence_LH self.tt.tree.sequence_LH = nuc_seq_LH for node in self.tree.find_clades(): node.attr[attr] = alphabet[node.sequence[0]] if node in nuc_seqs: node.sequence = nuc_seqs[node] if node.up is not None: node.__setattr__(attr+'_transitions', node.mutations) if node in nuc_muts: node.mutations = nuc_muts[node] # save marginal likelihoods if desired if report_confidence: node.attr[attr + "_entropy"] = sum([v * math.log(v+1E-20) for v in node.marginal_profile[0]]) * -1 / math.log(len(node.marginal_profile[0])) # javascript: vals.map((v) => v * Math.log(v + 1E-10)).reduce((a, b) => a + b, 0) * -1 / Math.log(vals.length); marginal = [(alphabet[self.tt.geogtr.alphabet[i]], node.marginal_profile[0][i]) for i in range(0, len(self.tt.geogtr.alphabet))] marginal.sort(key=lambda x: x[1], reverse=True) # sort on likelihoods marginal = [(a, b) for a, b in marginal if b > 0.01][:4] #only take stuff over 1% and the top 4 elements node.attr[attr + "_confidence"] = {a:b for a,b in marginal} self.tt.use_mutation_length=tmp_use_mutation_length # store saved attrs for save/restore functionality if not hasattr(self, "mugration_attrs"): self.mugration_attrs = [] self.mugration_attrs.append(attr) if report_confidence: self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"]) def restore_geo_inference(self, data, attr, confidence): if data == False: raise KeyError #yeah, not great for node in self.tree.find_clades(): node.attr[attr] = data[node.name][attr] if confidence: node.attr[attr+"_confidence"] = data[node.name][attr+"_confidence"] node.attr[attr+"_entropy"] = data[node.name][attr+"_entropy"] if not hasattr(self, "mugration_attrs"): self.mugration_attrs = [] self.mugration_attrs.append(attr) if confidence: self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"]) def get_attr_list(self, get_attr): states = [] for node in self.tree.find_clades(): if get_attr in node.attr: states.append(node.attr[get_attr]) return states def add_translations(self): ''' translate the nucleotide sequence into the proteins specified in self.proteins. these are expected to be SeqFeatures ''' from Bio import Seq # Sort proteins by start position of the corresponding SeqFeature entry. sorted_proteins = sorted(self.proteins.items(), key=lambda protein_pair: protein_pair[1].start) for node in self.tree.find_clades(order='preorder'): if not hasattr(node, "translations"): # Maintain genomic order of protein translations for easy # assembly by downstream functions. node.translations=OrderedDict() node.aa_mutations = {} for prot, feature in sorted_proteins: node.translations[prot] = Seq.translate(str(feature.extract(Seq.Seq("".join(node.sequence)))).replace('-', 'N')) if node.up is None: node.aa_mutations[prot] = [] else: node.aa_mutations[prot] = [(a,pos,d) for pos, (a,d) in enumerate(zip(node.up.translations[prot], node.translations[prot])) if a!=d] self.dump_attr.append('translations') def refine(self): ''' add attributes for export, currently this is only muts and aa_muts ''' self.tree.ladderize() for node in self.tree.find_clades(): if node.up is not None: node.muts = ["".join(map(str, [a, pos+1, d])) for a,pos,d in node.mutations if '-' not in [a,d]] # Sort all deletions by position to enable identification of # deletions >1 bp below. deletions = sorted( [(a,pos,d) for a,pos, d in node.mutations if '-' in [a,d]], key=lambda mutation: mutation[1] ) if len(deletions): length = 0 for pi, (a,pos,d) in enumerate(deletions[:-1]): if pos!=deletions[pi+1][1]-1: if length==0: node.muts.append(a+str(pos+1)+d) elif d=='-': node.muts.append("deletion %d-%d"%(pos-length, pos+1)) else: node.muts.append("insertion %d-%d"%(pos-length, pos+1)) else: length+=1 (a,pos,d) = deletions[-1] if length==0: node.muts.append(a+str(pos+1)+d) elif d=='-': node.muts.append("deletion %d-%d"%(pos-length, pos+1)) else: node.muts.append("insertion %d-%d"%(pos-length, pos+1)) node.aa_muts = {} if hasattr(node, 'translations'): for prot in node.translations: node.aa_muts[prot] = ["".join(map(str,[a,pos+1,d])) for a,pos,d in node.aa_mutations[prot]] for node in self.tree.find_clades(order="preorder"): if node.up is not None: #try: node.attr["div"] = node.up.attr["div"]+node.mutation_length else: node.attr["div"] = 0 self.dump_attr.extend(['muts', 'aa_muts', 'aa_mutations', 'mutation_length', 'mutations']) def layout(self): """Add clade, xvalue, yvalue, mutation and trunk attributes to all nodes in tree""" clade = 0 yvalue = self.tree.count_terminals() for node in self.tree.find_clades(order="preorder"): node.clade = clade clade += 1 if node.up is not None: #try: node.xvalue = node.up.xvalue+node.mutation_length if self.is_timetree: node.tvalue = node.numdate - self.tree.root.numdate else: node.tvalue = 0 else: node.xvalue = 0 node.tvalue = 0 if node.is_terminal(): node.yvalue = yvalue yvalue -= 1 for node in self.tree.get_nonterminals(order="postorder"): node.yvalue = np.mean([x.yvalue for x in node.clades]) self.dump_attr.extend(['yvalue', 'xvalue', 'clade']) if self.is_timetree: self.dump_attr.extend(['tvalue']) def export(self, path = '', extra_attr = ['aa_muts', 'clade'], plain_export = 10, indent=None, write_seqs_json=True): ''' export the tree data structure along with the sequence information as json files for display in web browsers. parameters: path -- path (incl prefix) to which the output files are written. filenames themselves are standardized to *tree.json and *sequences.json extra_attr -- attributes of tree nodes that are exported to json plain_export -- store sequences are plain strings instead of differences to root if number of differences exceeds len(seq)/plain_export ''' from Bio import Seq timetree_fname = path+'_tree.json' sequence_fname = path+'_sequences.json' tree_json = tree_to_json(self.tree.root, extra_attr=extra_attr) write_json(tree_json, timetree_fname, indent=indent) # prepare a json with sequence information to export. # first step: add the sequence & translations of the root as string elems = {} elems['root'] = {} elems['root']['nuc'] = "".join(self.tree.root.sequence) for prot,seq in self.tree.root.translations.items(): elems['root'][prot] = seq # add sequence for every node in tree. code as difference to root # or as full strings. for node in self.tree.find_clades(): if hasattr(node, "clade"): elems[node.clade] = {} # loop over proteins and nucleotide sequences for prot, seq in [('nuc', "".join(node.sequence))]+list(node.translations.items()): differences = {pos:state for pos, (state, ancstate) in enumerate(zip(seq, elems['root'][prot])) if state!=ancstate} if plain_export*len(differences)<=len(seq): elems[node.clade][prot] = differences else: elems[node.clade][prot] = seq if write_seqs_json: write_json(elems, sequence_fname, indent=indent)
class tree(object): """tree builds a phylgenetic tree from an alignment and exports it for web visualization""" def __init__(self, aln, proteins=None, verbose=2, logger=None, **kwarks): super(tree, self).__init__() self.aln = aln # self.nthreads = 2 self.sequence_lookup = {seq.id:seq for seq in aln} self.nuc = kwarks['nuc'] if 'nuc' in kwarks else True self.dump_attr = [] # depreciated self.verbose = verbose if proteins!=None: self.proteins = proteins else: self.proteins={} if 'run_dir' not in kwarks: import random self.run_dir = '_'.join(['temp', time.strftime('%Y%m%d-%H%M%S',time.gmtime()), str(random.randint(0,1000000))]) else: self.run_dir = kwarks['run_dir'] if logger is None: def f(x,y): if y<self.verbose: print(x) self.logger = f else: self.logger=logger def getDateMin(self): return self.tree.root.date def getDateMax(self): dateMax = self.tree.root.date for node in self.tree.find_clades(): if node.date > dateMax: dateMax = node.date return dateMax def dump(self, treefile, nodefile): from Bio import Phylo Phylo.write(self.tree, treefile, 'newick') node_props = {} for node in self.tree.find_clades(): node_props[node.name] = {attr:node.__getattribute__(attr) for attr in self.dump_attr if hasattr(node, attr)} with myopen(nodefile, 'w') as nfile: from cPickle import dump dump(node_props, nfile) def check_newick(self, newick_file): try: tree = Phylo.parse(newick_file, 'newick').next() assert(set([x.name for x in tree.get_terminals()]) == set(self.sequence_lookup.keys())) return True except: return False def build_newick(self, newick_file, nthreads=2, method="raxml", raxml_options={}, iqtree_options={}, debug=False): make_dir(self.run_dir) os.chdir(self.run_dir) for seq in self.aln: seq.name=seq.id out_fname = os.path.join("..", newick_file) if method=="raxml": self.build_newick_raxml(out_fname, nthreads=nthreads, **raxml_options) elif method=="fasttree": self.build_newick_fasttree(out_fname) elif method=="iqtree": self.build_newick_iqtree(out_fname, **iqtree_options) os.chdir('..') self.logger("Saved new tree to %s"%out_fname, 1) if not debug: remove_dir(self.run_dir) def build_newick_fasttree(self, out_fname): from Bio import Phylo, AlignIO AlignIO.write(self.aln, 'temp.fasta', 'fasta') self.logger("Building tree with fasttree", 1) tree_cmd = ["fasttree"] if self.nuc: tree_cmd.append("-nt") tree_cmd.extend(["temp.fasta","1>",out_fname, "2>", "fasttree_stderr"]) os.system(" ".join(tree_cmd)) def build_newick_raxml(self, out_fname, nthreads=2, raxml_bin="raxml", num_distinct_starting_trees=1, **kwargs): from Bio import Phylo, AlignIO import shutil self.logger("modified RAxML script - no branch length optimisation or time limit", 1) AlignIO.write(self.aln,"temp.phyx", "phylip-relaxed") if num_distinct_starting_trees == 1: cmd = raxml_bin + " -f d -T " + str(nthreads) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx" else: self.logger("RAxML running with {} starting trees (longer but better...)".format(num_distinct_starting_trees), 1) cmd = raxml_bin + " -f d -T " + str(nthreads) + " -N " + str(num_distinct_starting_trees) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx" try: with open("raxml.log", 'w') as fh: check_call(cmd, stdout=fh, stderr=STDOUT, shell=True) self.logger("RAXML COMPLETED.", 1) except CalledProcessError: self.logger("RAXML TREE FAILED - check {}/raxml.log".format(self.run_dir), 1) raise shutil.copy('RAxML_bestTree.tre', out_fname) def build_newick_iqtree(self, out_fname, nthreads=2, iqtree_bin="iqtree", iqmodel="HKY", **kwargs): from Bio import Phylo, AlignIO import shutil self.logger("modified RAxML script - no branch length optimisation or time limit", 1) aln_file = "temp.fasta" AlignIO.write(self.aln, aln_file, "fasta") with open(aln_file) as ifile: tmp_seqs = ifile.readlines() with open(aln_file, 'w') as ofile: for line in tmp_seqs: ofile.write(line.replace('/', '_X_X_')) if iqmodel: call = ["iqtree", "-nt", str(nthreads), "-s", aln_file, "-m", iqmodel, "-fast", ">", "iqtree.log"] else: call = ["iqtree", "-nt", str(nthreads), "-s", aln_file, ">", "iqtree.log"] os.system(" ".join(call)) T = Phylo.read(aln_file+".treefile", 'newick') for n in T.get_terminals(): n.name = n.name.replace('_X_X_','/') Phylo.write(T,out_fname, 'newick') def tt_from_file(self, infile, root='best', nodefile=None): self.is_timetree=False self.logger('Reading tree from file '+infile,2) dates = {seq.id:seq.attributes['num_date'] for seq in self.aln if 'date' in seq.attributes} self.tt = TreeTime(dates=dates, tree=str(infile), gtr='Jukes-Cantor', aln = self.aln, verbose=self.verbose, fill_overhangs=True) if root: self.tt.reroot(root=root) self.tree = self.tt.tree for node in self.tree.find_clades(): if node.is_terminal() and node.name in self.sequence_lookup: seq = self.sequence_lookup[node.name] node.attr = seq.attributes try: node.attr['date'] = node.attr['date'].strftime('%Y-%m-%d') except: pass else: node.attr = {} if nodefile is not None: self.logger('reading node properties from file: '+nodefile,2) with myopen(nodefile, 'r') as infile: from cPickle import load node_props = load(infile) for n in self.tree.find_clades(): if n.name in node_props: for attr in node_props[n.name]: n.__setattr__(attr, node_props[n.name][attr]) else: self.logger("No node properties found for "+n.name,2) def ancestral(self, **kwarks): self.tt.optimize_seq_and_branch_len(infer_gtr=True, **kwarks) self.dump_attr.append('sequence') for node in self.tree.find_clades(): if not hasattr(node,'attr'): node.attr = {} ## TODO REMOVE KWARKS - MAKE EXPLICIT def timetree(self, Tc=0.01, infer_gtr=True, reroot='best', resolve_polytomies=True, max_iter=2, confidence=False, use_marginal=False, **kwarks): self.logger('estimating time tree...',2) if confidence and use_marginal: marginal = 'assign' else: marginal = confidence self.tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal, resolve_polytomies=resolve_polytomies, max_iter=max_iter, **kwarks) self.logger('estimating time tree...done',3) self.dump_attr.extend(['numdate','date','sequence']) to_numdate = self.tt.date2dist.to_numdate for node in self.tree.find_clades(): if hasattr(node,'attr'): node.attr['num_date'] = node.numdate else: node.attr = {'num_date':node.numdate} if confidence: node.attr["num_date_confidence"] = sorted(self.tt.get_max_posterior_region(node, fraction=0.9)) self.is_timetree=True def save_timetree(self, fprefix, ttopts, cfopts): Phylo.write(self.tt.tree, fprefix+"_timetree.new", "newick") n = {} attrs = ["branch_length", "mutation_length", "clock_length", "dist2root", "name", "mutations", "attr", "cseq", "sequence", "numdate"] for node in self.tt.tree.find_clades(): n[node.name] = {} for x in attrs: n[node.name][x] = getattr(node, x) with open(fprefix+"_timetree.pickle", 'wb') as fh: pickle.dump({ "timetree_options": ttopts, "clock_filter_options": cfopts, "nodes": n, "original_seqs": self.sequence_lookup.keys(), }, fh, protocol=pickle.HIGHEST_PROTOCOL) def restore_timetree_node_info(self, nodes): for node in self.tt.tree.find_clades(): info = nodes[node.name] # print("restoring node info for node ", node.name) for k, v in info.iteritems(): setattr(node, k, v) self.is_timetree=True def remove_outlier_clades(self, max_nodes=3, min_length=0.03): ''' check whether one child clade of the root is small and the connecting branch is long. if so, move the root up and reset a few tree props Args: max_nodes number of nodes beyond which the outliers are note removed min_length minimal length of the branch connecting the outlier clade to the rest of the tree to allow cutting. Returns: list of names of strains that have been removed ''' R = self.tt.tree.root if len(R.clades)>2: return num_child_nodes = np.array([c.count_terminals() for c in R]) putative_outlier = num_child_nodes.argmin() bl = np.sum([c.branch_length for c in R]) if (bl>min_length and num_child_nodes[putative_outlier]<max_nodes): if num_child_nodes[putative_outlier]==1: print("removing \"{}\" which is an outlier clade".format(R.clades[putative_outlier].name)) else: print("removing {} isolates which formed an outlier clade".format(num_child_nodes[putative_outlier])) self.tt.tree.root = R.clades[(1+putative_outlier)%2] self.tt.prepare_tree() return [c.name for c in R.clades[putative_outlier].get_terminals()] def geo_inference(self, attr, missing='?', root_state=None, report_confidence=False): ''' infer a "mugration" model by pretending each region corresponds to a sequence state and repurposing the GTR inference and ancestral reconstruction ''' from treetime import GTR # Determine alphabet places = set() for node in self.tree.find_clades(): if hasattr(node, 'attr'): if attr in node.attr and attr!=missing: places.add(node.attr[attr]) if root_state is not None: places.add(root_state) # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45) places = sorted(places) nc = len(places) if nc>180: self.logger("geo_inference: can't have more than 180 places!",1) return elif nc==1: self.logger("geo_inference: only one place found -- setting every internal node to %s!"%places[0],1) for node in self.tree.find_clades(): node.attr[attr] = places[0] node.__setattr__(attr+'_transitions',[]) return elif nc==0: self.logger("geo_inference: list of places is empty!",1) return # store previously reconstructed sequences nuc_seqs = {} nuc_muts = {} nuc_seq_LH = None if hasattr(self.tt.tree,'sequence_LH'): nuc_seq_LH = self.tt.tree.sequence_LH for node in self.tree.find_clades(): if hasattr(node, 'sequence'): nuc_seqs[node] = node.sequence if hasattr(node, 'mutations'): nuc_muts[node] = node.mutations node.__delattr__('mutations') alphabet = {chr(65+i):place for i,place in enumerate(places)} sequence_gtr = self.tt.gtr myGeoGTR = GTR.custom(pi = np.ones(nc, dtype=float)/nc, W=np.ones((nc,nc)), alphabet = np.array(sorted(alphabet.keys()))) missing_char = chr(65+nc) alphabet[missing_char]=missing myGeoGTR.profile_map[missing_char] = np.ones(nc) alphabet_rev = {v:k for k,v in alphabet.iteritems()} # set geo info to nodes as one letter sequence. self.tt.seq_len = 1 for node in self.tree.get_terminals(): if hasattr(node, 'attr'): if attr in node.attr: node.sequence=np.array([alphabet_rev[node.attr[attr]]]) else: node.sequence=np.array([missing_char]) else: node.sequence=np.array([missing_char]) for node in self.tree.get_nonterminals(): node.__delattr__('sequence') if root_state is not None: self.tree.root.split(n=1, branch_length=0.0) extra_clade = self.tree.root.clades[-1] extra_clade.name = "dummy_root_node" extra_clade.up = self.tree.root extra_clade.sequence = np.array([alphabet_rev[root_state]]) self.tt.make_reduced_alignment() # set custom GTR model, run inference self.tt._gtr = myGeoGTR # import pdb; pdb.set_trace() tmp_use_mutation_length = self.tt.use_mutation_length self.tt.use_mutation_length=False self.tt.infer_ancestral_sequences(method='ml', infer_gtr=False, store_compressed=False, pc=5.0, marginal=True, normalized_rate=False) if root_state is not None: self.tree.prune(extra_clade) # restore the nucleotide sequence and mutations to maintain expected behavior self.tt.geogtr = self.tt.gtr self.tt.geogtr.alphabet_to_location = alphabet self.tt._gtr = sequence_gtr if hasattr(self.tt.tree,'sequence_LH'): self.tt.tree.geo_LH = self.tt.tree.sequence_LH self.tt.tree.sequence_LH = nuc_seq_LH for node in self.tree.find_clades(): node.attr[attr] = alphabet[node.sequence[0]] if node in nuc_seqs: node.sequence = nuc_seqs[node] if node.up is not None: node.__setattr__(attr+'_transitions', node.mutations) if node in nuc_muts: node.mutations = nuc_muts[node] # save marginal likelihoods if desired if report_confidence: node.attr[attr + "_entropy"] = sum([v * math.log(v+1E-20) for v in node.marginal_profile[0]]) * -1 / math.log(len(node.marginal_profile[0])) # javascript: vals.map((v) => v * Math.log(v + 1E-10)).reduce((a, b) => a + b, 0) * -1 / Math.log(vals.length); marginal = [(alphabet[self.tt.geogtr.alphabet[i]], node.marginal_profile[0][i]) for i in range(0, len(self.tt.geogtr.alphabet))] marginal.sort(key=lambda x: x[1], reverse=True) # sort on likelihoods marginal = [(a, b) for a, b in marginal if b > 0.01][:4] #only take stuff over 1% and the top 4 elements node.attr[attr + "_confidence"] = {a:b for a,b in marginal} self.tt.use_mutation_length=tmp_use_mutation_length # store saved attrs for save/restore functionality if not hasattr(self, "mugration_attrs"): self.mugration_attrs = [] self.mugration_attrs.append(attr) if report_confidence: self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"]) def restore_geo_inference(self, data, attr, confidence): if data == False: raise KeyError #yeah, not great for node in self.tree.find_clades(): node.attr[attr] = data[node.name][attr] if confidence: node.attr[attr+"_confidence"] = data[node.name][attr+"_confidence"] node.attr[attr+"_entropy"] = data[node.name][attr+"_entropy"] if not hasattr(self, "mugration_attrs"): self.mugration_attrs = [] self.mugration_attrs.append(attr) if confidence: self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"]) def get_attr_list(self, get_attr): states = [] for node in self.tree.find_clades(): if get_attr in node.attr: states.append(node.attr[get_attr]) return states def add_translations(self): ''' translate the nucleotide sequence into the proteins specified in self.proteins. these are expected to be SeqFeatures ''' from Bio import Seq # Sort proteins by start position of the corresponding SeqFeature entry. sorted_proteins = sorted(self.proteins.items(), key=lambda protein_pair: protein_pair[1].start) for node in self.tree.find_clades(order='preorder'): if not hasattr(node, "translations"): # Maintain genomic order of protein translations for easy # assembly by downstream functions. node.translations=OrderedDict() node.aa_mutations = {} for prot, feature in sorted_proteins: node.translations[prot] = Seq.translate(str(feature.extract(Seq.Seq("".join(node.sequence)))).replace('-', 'N')) if node.up is None: node.aa_mutations[prot] = [] else: node.aa_mutations[prot] = [(a,pos,d) for pos, (a,d) in enumerate(zip(node.up.translations[prot], node.translations[prot])) if a!=d] self.dump_attr.append('translations') def refine(self): ''' add attributes for export, currently this is only muts and aa_muts ''' self.tree.ladderize() for node in self.tree.find_clades(): if node.up is not None: node.muts = ["".join(map(str, [a, pos+1, d])) for a,pos,d in node.mutations if '-' not in [a,d]] # Sort all deletions by position to enable identification of # deletions >1 bp below. deletions = sorted( [(a,pos,d) for a,pos, d in node.mutations if '-' in [a,d]], key=lambda mutation: mutation[1] ) if len(deletions): length = 0 for pi, (a,pos,d) in enumerate(deletions[:-1]): if pos!=deletions[pi+1][1]-1: if length==0: node.muts.append(a+str(pos+1)+d) elif d=='-': node.muts.append("deletion %d-%d"%(pos-length, pos+1)) else: node.muts.append("insertion %d-%d"%(pos-length, pos+1)) else: length+=1 (a,pos,d) = deletions[-1] if length==0: node.muts.append(a+str(pos+1)+d) elif d=='-': node.muts.append("deletion %d-%d"%(pos-length, pos+1)) else: node.muts.append("insertion %d-%d"%(pos-length, pos+1)) node.aa_muts = {} if hasattr(node, 'translations'): for prot in node.translations: node.aa_muts[prot] = ["".join(map(str,[a,pos+1,d])) for a,pos,d in node.aa_mutations[prot]] for node in self.tree.find_clades(order="preorder"): if node.up is not None: #try: node.attr["div"] = node.up.attr["div"]+node.mutation_length else: node.attr["div"] = 0 self.dump_attr.extend(['muts', 'aa_muts', 'aa_mutations', 'mutation_length', 'mutations']) def layout(self): """Add clade, xvalue, yvalue, mutation and trunk attributes to all nodes in tree""" clade = 0 yvalue = self.tree.count_terminals() for node in self.tree.find_clades(order="preorder"): node.clade = clade clade += 1 if node.up is not None: #try: node.xvalue = node.up.xvalue+node.mutation_length if self.is_timetree: node.tvalue = node.numdate - self.tree.root.numdate else: node.tvalue = 0 else: node.xvalue = 0 node.tvalue = 0 if node.is_terminal(): node.yvalue = yvalue yvalue -= 1 for node in self.tree.get_nonterminals(order="postorder"): node.yvalue = np.mean([x.yvalue for x in node.clades]) self.dump_attr.extend(['yvalue', 'xvalue', 'clade']) if self.is_timetree: self.dump_attr.extend(['tvalue']) def export(self, path = '', extra_attr = ['aa_muts', 'clade'], plain_export = 10, indent=None, write_seqs_json=True): ''' export the tree data structure along with the sequence information as json files for display in web browsers. parameters: path -- path (incl prefix) to which the output files are written. filenames themselves are standardized to *tree.json and *sequences.json extra_attr -- attributes of tree nodes that are exported to json plain_export -- store sequences are plain strings instead of differences to root if number of differences exceeds len(seq)/plain_export ''' from Bio import Seq from itertools import izip timetree_fname = path+'_tree.json' sequence_fname = path+'_sequences.json' tree_json = tree_to_json(self.tree.root, extra_attr=extra_attr) write_json(tree_json, timetree_fname, indent=indent) # prepare a json with sequence information to export. # first step: add the sequence & translations of the root as string elems = {} elems['root'] = {} elems['root']['nuc'] = "".join(self.tree.root.sequence) for prot,seq in self.tree.root.translations.iteritems(): elems['root'][prot] = seq # add sequence for every node in tree. code as difference to root # or as full strings. for node in self.tree.find_clades(): if hasattr(node, "clade"): elems[node.clade] = {} # loop over proteins and nucleotide sequences for prot, seq in [('nuc', "".join(node.sequence))]+node.translations.items(): differences = {pos:state for pos, (state, ancstate) in enumerate(izip(seq, elems['root'][prot])) if state!=ancstate} if plain_export*len(differences)<=len(seq): elems[node.clade][prot] = differences else: elems[node.clade][prot] = seq if write_seqs_json: write_json(elems, sequence_fname, indent=indent)