def treetime_from_newick(gtr, infile):
    """
    Build a TreeTime object around a tree parsed from a newick file.

    Args:
     - gtr: forwarded unchanged as the single positional argument of the
       TreeTime constructor.
     - infile(str): path to the newick file.

    Returns:
     - TreeTime: the newly created object, with its ``tree`` attribute set to
       the parsed tree and ``set_additional_tree_params()`` already called.
    """
    tree_time = TreeTime(gtr)
    newick_tree = Phylo.read(infile, 'newick')
    tree_time.tree = newick_tree
    # let TreeTime attach whatever per-node parameters it needs
    tree_time.set_additional_tree_params()
    return tree_time
def tt_from_file(self, infile, root='best', nodefile=None):
    """
    Read a tree from file into a TreeTime object and decorate its nodes.

    Sets ``self.tt`` (the TreeTime instance), ``self.tree`` (its tree) and
    ``self.is_timetree = False``. Every node receives an ``attr`` dict:
    terminal nodes found in ``self.sequence_lookup`` share the ``attributes``
    dict of their sequence, all other nodes get an empty dict.

    Args:
        infile: path of the tree file, passed to TreeTime as ``str(infile)``.
        root: rerooting argument handed to ``TreeTime.reroot``; a falsy value
              skips rerooting.
        nodefile: optional path of a pickled ``{node name: {attribute: value}}``
                  dict whose entries are set as attributes on matching nodes.
    """
    self.is_timetree = False
    self.logger('Reading tree from file '+infile, 2)
    # NOTE(review): the filter tests for 'date' but reads 'num_date' -- this
    # presumably relies on both keys always being set together; confirm.
    dates = {seq.id: seq.attributes['num_date']
             for seq in self.aln if 'date' in seq.attributes}
    self.tt = TreeTime(dates=dates, tree=str(infile), gtr='Jukes-Cantor',
                       aln=self.aln, verbose=self.verbose, fill_overhangs=True)
    if root:
        self.tt.reroot(root=root)
    self.tree = self.tt.tree

    for node in self.tree.find_clades():
        if node.is_terminal() and node.name in self.sequence_lookup:
            seq = self.sequence_lookup[node.name]
            node.attr = seq.attributes
            try:
                node.attr['date'] = node.attr['date'].strftime('%Y-%m-%d')
            except (KeyError, AttributeError, TypeError, ValueError):
                # best effort: the date is missing, is not a date object, or
                # cannot be formatted -- leave the attributes untouched
                pass
        else:
            node.attr = {}

    if nodefile is not None:
        self.logger('reading node properties from file: '+nodefile, 2)
        # NOTE: unpickling runs arbitrary code; only load trusted node files.
        with myopen(nodefile, 'r') as node_fh:
            node_props = pickle.load(node_fh)
        for n in self.tree.find_clades():
            if n.name in node_props:
                for attr, value in node_props[n.name].items():
                    setattr(n, attr, value)
            else:
                self.logger("No node properties found for "+n.name, 2)
def refine(tree=None, aln=None, ref=None, dates=None, branch_length_inference='auto',
           confidence=False, resolve_polytomies=True, max_iter=2, infer_gtr=True,
           Tc=0.01, reroot=None, use_marginal=False, fixed_pi=None,
           clock_rate=None, clock_std=None, clock_filter_iqd=None,
           verbosity=1, **kwarks):
    """
    Set up a TreeTime object, optionally prune clock outliers, and run it.

    Args:
        tree, aln, ref, dates: forwarded to the TreeTime constructor.
        branch_length_inference: passed to ``run`` as ``branch_length_mode``;
            'auto' is replaced by 'joint' when ``ref`` is given.
        confidence: if truthy, ``num_date_confidence`` (the 0.9 max posterior
            region) is attached to every node after the run.
        resolve_polytomies, max_iter, infer_gtr: forwarded to ``run``.
        Tc: a number or a string such as 'opt' or 'skyline'; converted to
            float when possible, otherwise passed on unchanged.
        reroot: passed to ``run`` as ``root``.
        use_marginal: together with ``confidence`` selects
            ``time_marginal='assign'``; otherwise ``time_marginal=confidence``.
        fixed_pi: equilibrium frequencies; derived from ``ref`` when ``ref``
            is given and ``fixed_pi`` is None.
        clock_rate, clock_std: ``clock_rate`` is passed as
            ``fixed_clock_rate``; when both are truthy ``vary_rate`` is set
            to ``clock_std``, otherwise to True.
        clock_filter_iqd: if truthy, run the clock filter with this ``n_iqd``
            and prune every leaf flagged as ``bad_branch``.
        verbosity: passed to TreeTime as ``verbose``.
        **kwarks: forwarded to ``run``.

    Returns:
        The TreeTime object after ``run``.
    """
    from treetime import TreeTime

    try:
        # Tc could be a number or 'opt' or 'skyline'. TreeTime expects a
        # float or int if a number.
        Tc = float(Tc)
    except (TypeError, ValueError):
        # not numeric (a keyword string, or None): hand it on unchanged
        pass

    if (ref is not None) and (fixed_pi is None):  # if VCF, fix pi
        # Otherwise mutation TO gaps is overestimated b/c of seq length
        fixed_pi = [ref.count(base)/len(ref) for base in ['A', 'C', 'G', 'T', '-']]
        if fixed_pi[-1] == 0:
            # give gaps a weight of 0.05 and take 0.01 off each of the five
            # entries, which leaves the total unchanged
            fixed_pi[-1] = 0.05
            fixed_pi = [v-0.01 for v in fixed_pi]

    if ref is not None:  # VCF -> adjust branch length
        # set branch length mode explicitly if auto, as informative-site only
        # trees can have big branch lengths, making this set incorrectly in TreeTime
        if branch_length_inference == 'auto':
            branch_length_inference = 'joint'

    # send ref, if is None, does no harm
    tt = TreeTime(tree=tree, aln=aln, ref=ref, dates=dates,
                  verbose=verbosity, gtr='JC69')

    # conditionally run clock-filter and remove bad tips
    if clock_filter_iqd:
        # treetime clock filter will mark, but not remove bad tips
        tt.clock_filter(reroot='best', n_iqd=clock_filter_iqd, plot=False)
        # remove them explicitly; iterate over a copy because pruning
        # changes the tree while we walk its leaves
        for n in list(tt.tree.get_terminals()):
            if n.bad_branch:
                tt.tree.prune(n)
                print('pruning leaf ', n.name)
        # fix treetime set-up for new tree topology
        tt.prepare_tree()

    if confidence and use_marginal:
        # estimate confidence intervals via marginal ML and assign
        # marginal ML times to nodes
        marginal = 'assign'
    else:
        marginal = confidence

    if clock_rate and clock_std:
        vary_rate = clock_std
    else:
        vary_rate = True

    tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal,
           branch_length_mode=branch_length_inference,
           resolve_polytomies=resolve_polytomies, max_iter=max_iter,
           fixed_pi=fixed_pi, fixed_clock_rate=clock_rate,
           vary_rate=vary_rate, **kwarks)

    if confidence:
        for n in tt.tree.find_clades():
            n.num_date_confidence = list(tt.get_max_posterior_region(n, 0.9))

    print("\nInferred a time resolved phylogeny using TreeTime:"
          "\n\tSagulenko et al. TreeTime: Maximum-likelihood phylodynamic analysis"
          "\n\tVirus Evolution, vol 4, https://academic.oup.com/ve/article/4/1/vex042/4794731\n")
    return tt
def estimate_clock_model(params):
    """
    Run the root-to-tip regression ("treetime clock") for the given arguments.

    Loads the tree, optionally filters outliers and reroots, prints the
    regression result and the implied root date, writes the (possibly
    rerooted) tree and a table of dates and root-to-tip distances, and
    finally calls ``plot_rtt``.

    Args:
        params: argument namespace; the attributes read here are ``dates``,
            ``aln``, ``sequence_length``, ``tree``, ``verbose``, ``tip_slack``,
            ``clock_filter``, ``reroot``, ``keep_root``, ``covariation``,
            ``allow_negative_rate`` and ``plot_rtt`` (others may be read by
            the helpers it is passed to).

    Returns:
        0 on success, 1 if the tree or dates could not be obtained, a required
        argument is missing, or rerooting failed.
    """
    if assure_tree(params, tmp_dir='clock_model_tmp'):
        return 1
    dates = utils.parse_dates(params.dates)
    if len(dates) == 0:
        return 1

    outdir = get_outdir(params, '_clock')

    ###########################################################################
    ### READ IN VCF
    ###########################################################################
    # sets ref and fixed_pi to None if not VCF
    aln, ref, fixed_pi = read_if_vcf(params)

    ###########################################################################
    ### ESTIMATE ROOT (if requested) AND DETERMINE TEMPORAL SIGNAL
    ###########################################################################
    if params.aln is None and params.sequence_length is None:
        print("one of arguments '--aln' and '--sequence-length' is required.", file=sys.stderr)
        return 1

    basename = get_basename(params, outdir)
    myTree = TreeTime(dates=dates, tree=params.tree, aln=aln, gtr='JC69',
                      verbose=params.verbose, seq_len=params.sequence_length,
                      ref=ref)
    myTree.tip_slack = params.tip_slack
    if myTree.tree is None:
        print("ERROR: tree loading failed. exiting...")
        return 1

    if params.clock_filter:
        # compare the flagged leaves before and after filtering so that only
        # newly excluded leaves are reported
        n_bad = [n.name for n in myTree.tree.get_terminals() if n.bad_branch]
        myTree.clock_filter(n_iqd=params.clock_filter, reroot=params.reroot or 'least-squares')
        n_bad_after = [n.name for n in myTree.tree.get_terminals() if n.bad_branch]
        if len(n_bad_after) > len(n_bad):
            print("The following leaves don't follow a loose clock and "
                  "will be ignored in rate estimation:\n\t"
                  + "\n\t".join(set(n_bad_after).difference(n_bad)))

    if not params.keep_root:
        # reroot to optimal root, this assigns clock_model to myTree
        if params.covariation:  # this requires branch length estimates
            myTree.run(root="least-squares", max_iter=0,
                       use_covariation=params.covariation)

        res = myTree.reroot(params.reroot,
                            force_positive=not params.allow_negative_rate)
        # bail out before fitting a clock model to a tree that was not rerooted
        if res == ttconf.ERROR:
            print("ERROR: unknown root or rooting mechanism!\n"
                  "\tvalid choices are 'least-squares', 'ML', and 'ML-rough'")
            return 1
        myTree.get_clock_model(covariation=params.covariation)
    else:
        myTree.get_clock_model(covariation=params.covariation)

    d2d = utils.DateConversion.from_regression(myTree.clock_model)
    print('\n', d2d)
    print('The R^2 value indicates the fraction of variation in'
          '\nroot-to-tip distance explained by the sampling times.'
          '\nHigher values corresponds more clock-like behavior (max 1.0).')

    print('\nThe rate is the slope of the best fit of the date to'
          '\nthe root-to-tip distance and provides an estimate of'
          '\nthe substitution rate. The rate needs to be positive!'
          '\nNegative rates suggest an inappropriate root.\n')

    print('\nThe estimated rate and tree correspond to a root date:')
    if params.covariation:
        # propagate the regression covariance to the root date
        # (-intercept/slope) through its gradient dp
        reg = myTree.clock_model
        dp = np.array([reg['intercept']/reg['slope']**2, -1./reg['slope']])
        droot = np.sqrt(reg['cov'][:2, :2].dot(dp).dot(dp))
        print('\n--- root-date:\t %3.2f +/- %1.2f (one std-dev)\n\n'%(-d2d.intercept/d2d.clock_rate, droot))
    else:
        print('\n--- root-date:\t %3.2f\n\n'%(-d2d.intercept/d2d.clock_rate))

    if not params.keep_root:
        # write rerooted tree to file
        outtree_name = basename+'rerooted.newick'
        Phylo.write(myTree.tree, outtree_name, 'newick')
        print("--- re-rooted tree written to \n\t%s\n"%outtree_name)

    table_fname = basename+'rtt.csv'
    with open(table_fname, 'w') as ofile:
        ofile.write("#name, date, root-to-tip distance\n")
        ofile.write("#Dates of nodes that didn't have a specified date are inferred from the root-to-tip regression.\n")
        for n in myTree.tree.get_terminals():
            if hasattr(n, "raw_date_constraint") and (n.raw_date_constraint is not None):
                # a constraint is either a single value or a (lower, upper) pair
                if np.isscalar(n.raw_date_constraint):
                    tmp_str = str(n.raw_date_constraint)
                elif len(n.raw_date_constraint):
                    tmp_str = str(n.raw_date_constraint[0])+'-'+str(n.raw_date_constraint[1])
                else:
                    tmp_str = ''
                ofile.write("%s, %s, %f\n"%(n.name, tmp_str, n.dist2root))
            else:
                ofile.write("%s, %f, %f\n"%(n.name, d2d.numdate_from_dist2root(n.dist2root), n.dist2root))
        for n in myTree.tree.get_nonterminals(order='preorder'):
            ofile.write("%s, %f, %f\n"%(n.name, d2d.numdate_from_dist2root(n.dist2root), n.dist2root))
    print("--- wrote dates and root-to-tip distances to \n\t%s\n"%table_fname)

    ###########################################################################
    ### PLOT AND SAVE RESULT
    ###########################################################################
    plot_rtt(myTree, outdir+params.plot_rtt)
    return 0
def timetree(params):
    """
    Run the full time tree inference ("treetime tree") for the given arguments.

    Sets up a TreeTime object from the tree, dates and alignment/VCF, runs it
    with the requested clock, coalescent and confidence settings, prints a
    summary, saves the tree and root-to-tip plots, optionally writes branch
    specific rates, and exports sequences and tree.

    Args:
        params: argument namespace; the attributes read here include
            ``relax``, ``dates``, ``gtr``, ``aln``, ``sequence_length``,
            ``tree``, ``branch_length_mode``, ``coalescent``, ``confidence``,
            ``clock_std_dev``, ``covariation``, ``keep_root``, ``reroot``,
            ``keep_polytomies``, ``max_iter``, ``clock_rate``,
            ``clock_filter``, ``plot_tree`` and ``plot_rtt``.

    Returns:
        0 on success, 1 on any set-up or run failure.
    """
    if params.relax is None:
        relaxed_clock_params = None
    elif params.relax == []:
        relaxed_clock_params = True
    elif len(params.relax) == 2:
        relaxed_clock_params = {'slack': params.relax[0], 'coupling': params.relax[1]}
    else:
        # any other number of values used to surface later as a NameError
        print("relaxed clock specification has to be either empty or consist of "
              "two numbers (slack, coupling) -- exiting")
        return 1

    dates = utils.parse_dates(params.dates)
    if len(dates) == 0:
        print("No valid dates -- exiting.")
        return 1

    if assure_tree(params, tmp_dir='timetree_tmp'):
        print("No tree -- exiting.")
        return 1

    outdir = get_outdir(params, '_treetime')

    gtr = create_gtr(params)
    infer_gtr = params.gtr == 'infer'

    ###########################################################################
    ### READ IN VCF
    ###########################################################################
    # sets ref and fixed_pi to None if not VCF
    aln, ref, fixed_pi = read_if_vcf(params)
    is_vcf = ref is not None
    branch_length_mode = params.branch_length_mode
    # variable-site-only trees can have big branch lengths, the auto setting won't work.
    if is_vcf or (params.aln and params.sequence_length):
        if branch_length_mode == 'auto':
            branch_length_mode = 'joint'

    ###########################################################################
    ### SET-UP and RUN
    ###########################################################################
    if params.aln is None and params.sequence_length is None:
        print("one of arguments '--aln' and '--sequence-length' is required.", file=sys.stderr)
        return 1
    myTree = TreeTime(dates=dates, tree=params.tree, ref=ref,
                      aln=aln, gtr=gtr, seq_len=params.sequence_length,
                      verbose=params.verbose)
    myTree.tip_slack = params.tip_slack
    if not myTree.one_mutation:
        print("TreeTime setup failed, exiting")
        return 1

    # coalescent model options
    try:
        coalescent = float(params.coalescent)
    except (TypeError, ValueError):
        # not a number: only the named models are acceptable
        if params.coalescent in ['opt', 'const', 'skyline']:
            coalescent = params.coalescent
        else:
            print("unknown coalescent model specification, has to be either "
                  "a float, 'opt', 'const' or 'skyline' -- exiting")
            return 1
    else:
        # numeric values below ten times one_mutation switch the model off
        if coalescent < 10*myTree.one_mutation:
            coalescent = None

    # determine whether confidence intervals are to be computed and how the
    # uncertainty in the rate estimate should be treated
    calc_confidence = params.confidence
    if params.clock_std_dev:
        vary_rate = params.clock_std_dev if calc_confidence else False
    elif params.confidence and params.covariation:
        vary_rate = True
    elif params.confidence:
        print("\nOutside of covariance aware mode TreeTime cannot estimate confidence intervals "
              "without specified standard deviation of the clock rate Please specify '--clock-std-dev' "
              "or rerun with '--covariance'. Will proceed without confidence estimation")
        vary_rate = False
        calc_confidence = False
    else:
        vary_rate = False

    # RUN
    root = None if params.keep_root else params.reroot
    success = myTree.run(root=root, relaxed_clock=relaxed_clock_params,
                         resolve_polytomies=(not params.keep_polytomies),
                         Tc=coalescent, max_iter=params.max_iter,
                         fixed_clock_rate=params.clock_rate,
                         n_iqd=params.clock_filter,
                         time_marginal="assign" if calc_confidence else False,
                         vary_rate=vary_rate,
                         branch_length_mode=branch_length_mode,
                         fixed_pi=fixed_pi,
                         use_covariation=params.covariation)
    if success == ttconf.ERROR:  # if TreeTime.run failed, exit
        print("\nTreeTime run FAILED: please check above for errors and/or rerun with --verbose 4.\n")
        return 1

    ###########################################################################
    ### OUTPUT and saving of results
    ###########################################################################
    if infer_gtr:
        print('\nInferred GTR model:')
        print(myTree.gtr)

    print(myTree.date2dist)

    basename = get_basename(params, outdir)
    if coalescent in ['skyline', 'opt', 'const']:
        print("Inferred coalescent model")
        if coalescent == 'skyline':
            print_save_plot_skyline(myTree, plot=basename+'skyline.pdf', save=basename+'skyline.tsv', screen=True)
        else:
            Tc = myTree.merger_model.Tc.y[0]
            print(" --T_c: \t %1.2e \toptimized inverse merger rate in units of substitutions"%Tc)
            print(" --T_c: \t %1.2e \toptimized inverse merger rate in years"%(Tc/myTree.date2dist.clock_rate))
            print(" --N_e: \t %1.2e \tcorresponding 'effective population size' assuming 50 gen/year\n"%(Tc/myTree.date2dist.clock_rate*50))

    # plot
    import matplotlib.pyplot as plt
    from .treetime import plot_vs_years
    leaf_count = myTree.tree.count_terminals()
    # label tips on small trees unless disabled, or whenever labels are forced
    label_func = lambda x: (x.name if x.is_terminal() and ((leaf_count<30
                                        and (not params.no_tip_labels))
                                      or params.tip_labels) else '')

    # use calc_confidence rather than params.confidence: confidence may have
    # been switched off above, in which case no intervals were computed
    plot_vs_years(myTree, show_confidence=False, label_func=label_func,
                  confidence=0.9 if calc_confidence else None)
    tree_fname = (outdir + params.plot_tree)
    plt.savefig(tree_fname)
    print("--- saved tree as \n\t %s\n"%tree_fname)

    plot_rtt(myTree, outdir + params.plot_rtt)
    if params.relax:
        fname = outdir+'substitution_rates.tsv'
        print("--- wrote branch specific rates to\n\t %s\n"%fname)
        with open(fname, 'w') as fh:
            fh.write("#node\tclock_length\tmutation_length\trate\tfold_change\n")
            for n in myTree.tree.find_clades(order="preorder"):
                if n == myTree.tree.root:
                    continue
                g = n.branch_length_interpolator.gamma
                fh.write("%s\t%1.3e\t%1.3e\t%1.3e\t%1.2f\n"%(n.name, n.clock_length, n.mutation_length, myTree.date2dist.clock_rate*g, g))

    export_sequences_and_tree(myTree, basename, is_vcf, params.zero_based,
                              timetree=True, confidence=calc_confidence)

    return 0
class tree(object):
    """tree builds a phylogenetic tree from an alignment and exports it for web visualization"""

    def __init__(self, aln, proteins=None, verbose=2, logger=None, **kwarks):
        """
        Store the alignment and set up bookkeeping.

        Args:
            aln: iterable of sequence records exposing an ``id`` attribute.
            proteins: dict of protein name -> feature; defaults to ``{}``.
            verbose: threshold used by the default logger.
            logger: callable ``logger(message, level)``; when None a default
                    is installed that prints messages whose level is below
                    ``self.verbose``.
            **kwarks: ``nuc`` (default True) and ``run_dir`` (default: a
                      time-stamped, randomised ``temp_...`` name) are read.
        """
        super(tree, self).__init__()
        self.aln = aln
        self.sequence_lookup = {seq.id: seq for seq in aln}
        self.nuc = kwarks.get('nuc', True)
        self.dump_attr = []  # deprecated
        self.verbose = verbose
        self.proteins = proteins if proteins is not None else {}
        if 'run_dir' not in kwarks:
            import random
            self.run_dir = '_'.join(['temp', time.strftime('%Y%m%d-%H%M%S', time.gmtime()),
                                     str(random.randint(0, 1000000))])
        else:
            self.run_dir = kwarks['run_dir']
        if logger is None:
            def f(x, y):
                if y < self.verbose:
                    print(x)
            self.logger = f
        else:
            self.logger = logger

    def getDateMin(self):
        """Return the ``date`` attribute of the root node."""
        return self.tree.root.date

    def getDateMax(self):
        """Return the largest ``date`` found on any node of the tree."""
        dateMax = self.tree.root.date
        for node in self.tree.find_clades():
            if node.date > dateMax:
                dateMax = node.date
        return dateMax

    def dump(self, treefile, nodefile):
        """
        Write the tree as newick to ``treefile`` and pickle, per node name,
        the attributes listed in ``self.dump_attr`` to ``nodefile``.
        """
        from Bio import Phylo
        Phylo.write(self.tree, treefile, 'newick')
        node_props = {}
        for node in self.tree.find_clades():
            node_props[node.name] = {attr: getattr(node, attr)
                                     for attr in self.dump_attr if hasattr(node, attr)}
        with myopen(nodefile, 'w') as nfile:
            pickle.dump(node_props, nfile)

    def check_newick(self, newick_file):
        """
        Return True if the first tree in ``newick_file`` can be parsed and its
        terminal names equal the ids in ``self.sequence_lookup``; False
        otherwise (including when the file cannot be read or parsed).
        """
        try:
            # next() works on Python 2 and 3, unlike the .next() method
            parsed = next(Phylo.parse(newick_file, 'newick'))
            tip_names = set(x.name for x in parsed.get_terminals())
        except Exception:
            return False
        # explicit comparison instead of assert, which is stripped under -O
        return tip_names == set(self.sequence_lookup.keys())

    def build_newick(self, newick_file, nthreads=2, method="raxml", raxml_options=None,
                     iqtree_options=None, debug=False):
        """
        Build a tree with the chosen tool inside ``self.run_dir`` and write it
        to ``newick_file`` (taken relative to the parent of the run directory).

        Args:
            newick_file: output file name.
            nthreads: thread count, passed to the raxml builder only.
            method: "raxml", "fasttree" or "iqtree"; anything else builds nothing.
            raxml_options, iqtree_options: extra keyword arguments for the
                respective builder (None means no extra arguments).
            debug: when true the run directory is kept.
        """
        raxml_options = {} if raxml_options is None else raxml_options
        iqtree_options = {} if iqtree_options is None else iqtree_options
        make_dir(self.run_dir)
        os.chdir(self.run_dir)
        try:
            for seq in self.aln:
                seq.name = seq.id
            out_fname = os.path.join("..", newick_file)
            if method == "raxml":
                self.build_newick_raxml(out_fname, nthreads=nthreads, **raxml_options)
            elif method == "fasttree":
                self.build_newick_fasttree(out_fname)
            elif method == "iqtree":
                self.build_newick_iqtree(out_fname, **iqtree_options)
        finally:
            # always step back out, even when the tree builder raised
            os.chdir('..')
        self.logger("Saved new tree to %s"%out_fname, 1)
        if not debug:
            remove_dir(self.run_dir)

    def build_newick_fasttree(self, out_fname):
        """Write the alignment to temp.fasta and run fasttree, redirecting
        its stdout to ``out_fname`` and stderr to ``fasttree_stderr``."""
        from Bio import AlignIO
        AlignIO.write(self.aln, 'temp.fasta', 'fasta')
        self.logger("Building tree with fasttree", 1)
        tree_cmd = ["fasttree"]
        if self.nuc:
            tree_cmd.append("-nt")
        # NOTE(review): the command goes through the shell unquoted; file
        # names with spaces or shell metacharacters will break it.
        tree_cmd.extend(["temp.fasta", "1>", out_fname, "2>", "fasttree_stderr"])
        os.system(" ".join(tree_cmd))

    def build_newick_raxml(self, out_fname, nthreads=2, raxml_bin="raxml",
                           num_distinct_starting_trees=1, **kwargs):
        """
        Run RAxML on the alignment and copy ``RAxML_bestTree.tre`` to
        ``out_fname``. Output of the run is written to ``raxml.log``.

        Raises:
            CalledProcessError: re-raised when the RAxML command fails.
        """
        from Bio import AlignIO
        import shutil
        self.logger("modified RAxML script - no branch length optimisation or time limit", 1)
        AlignIO.write(self.aln, "temp.phyx", "phylip-relaxed")
        if num_distinct_starting_trees == 1:
            cmd = raxml_bin + " -f d -T " + str(nthreads) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx"
        else:
            self.logger("RAxML running with {} starting trees (longer but better...)".format(num_distinct_starting_trees), 1)
            cmd = raxml_bin + " -f d -T " + str(nthreads) + " -N " + str(num_distinct_starting_trees) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx"

        try:
            with open("raxml.log", 'w') as fh:
                check_call(cmd, stdout=fh, stderr=STDOUT, shell=True)
            self.logger("RAXML COMPLETED.", 1)
        except CalledProcessError:
            self.logger("RAXML TREE FAILED - check {}/raxml.log".format(self.run_dir), 1)
            raise
        shutil.copy('RAxML_bestTree.tre', out_fname)

    def build_newick_iqtree(self, out_fname, nthreads=2, iqtree_bin="iqtree",
                            iqmodel="HKY", **kwargs):
        """
        Run IQ-TREE on the alignment and write the resulting tree to ``out_fname``.

        '/' in the alignment file is replaced by '_X_X_' before the run and
        restored in the terminal names of the tree afterwards.

        Args:
            out_fname: output newick file.
            nthreads: value for the ``-nt`` option.
            iqtree_bin: name or path of the executable.
            iqmodel: model for ``-m`` (also enables ``-fast``); a falsy value
                     omits both options.
        """
        from Bio import Phylo, AlignIO
        self.logger("Building tree with iqtree", 1)
        aln_file = "temp.fasta"
        AlignIO.write(self.aln, aln_file, "fasta")
        with open(aln_file) as ifile:
            tmp_seqs = ifile.readlines()
        with open(aln_file, 'w') as ofile:
            for line in tmp_seqs:
                ofile.write(line.replace('/', '_X_X_'))

        # honour iqtree_bin instead of a hard-coded executable name
        call = [iqtree_bin, "-nt", str(nthreads), "-s", aln_file]
        if iqmodel:
            call.extend(["-m", iqmodel, "-fast"])
        call.extend([">", "iqtree.log"])
        os.system(" ".join(call))
        T = Phylo.read(aln_file+".treefile", 'newick')
        for n in T.get_terminals():
            n.name = n.name.replace('_X_X_', '/')
        Phylo.write(T, out_fname, 'newick')

    def tt_from_file(self, infile, root='best', nodefile=None):
        """
        Read a tree from file into a TreeTime object and decorate its nodes.

        Sets ``self.tt``, ``self.tree`` and ``self.is_timetree = False``.
        Terminal nodes found in ``self.sequence_lookup`` share the
        ``attributes`` dict of their sequence as ``attr``; all other nodes
        get an empty dict.

        Args:
            infile: path of the tree file.
            root: argument for ``TreeTime.reroot``; falsy skips rerooting.
            nodefile: optional path of a pickled ``{node name: {attribute:
                      value}}`` dict whose entries are set on matching nodes.
        """
        self.is_timetree = False
        self.logger('Reading tree from file '+infile, 2)
        # NOTE(review): the filter tests for 'date' but reads 'num_date' --
        # presumably both keys are always set together; confirm.
        dates = {seq.id: seq.attributes['num_date']
                 for seq in self.aln if 'date' in seq.attributes}
        self.tt = TreeTime(dates=dates, tree=str(infile), gtr='Jukes-Cantor',
                           aln=self.aln, verbose=self.verbose, fill_overhangs=True)
        if root:
            self.tt.reroot(root=root)
        self.tree = self.tt.tree

        for node in self.tree.find_clades():
            if node.is_terminal() and node.name in self.sequence_lookup:
                seq = self.sequence_lookup[node.name]
                node.attr = seq.attributes
                try:
                    node.attr['date'] = node.attr['date'].strftime('%Y-%m-%d')
                except (KeyError, AttributeError, TypeError, ValueError):
                    # best effort: date missing, not a date object, or not
                    # formattable -- leave the attributes untouched
                    pass
            else:
                node.attr = {}

        if nodefile is not None:
            self.logger('reading node properties from file: '+nodefile, 2)
            # NOTE: unpickling runs arbitrary code; only load trusted files.
            with myopen(nodefile, 'r') as node_fh:
                node_props = pickle.load(node_fh)
            for n in self.tree.find_clades():
                if n.name in node_props:
                    for attr, value in node_props[n.name].items():
                        setattr(n, attr, value)
                else:
                    self.logger("No node properties found for "+n.name, 2)

    def ancestral(self, **kwarks):
        """Run ``optimize_seq_and_branch_len`` with ``infer_gtr=True`` and
        make sure every node has an ``attr`` dict."""
        self.tt.optimize_seq_and_branch_len(infer_gtr=True, **kwarks)
        self.dump_attr.append('sequence')
        for node in self.tree.find_clades():
            if not hasattr(node, 'attr'):
                node.attr = {}

    ## TODO REMOVE KWARKS - MAKE EXPLICIT
    def timetree(self, Tc=0.01, infer_gtr=True, reroot='best', resolve_polytomies=True,
                 max_iter=2, confidence=False, use_marginal=False, **kwarks):
        """
        Run TreeTime and copy the inferred dates into ``node.attr``.

        ``time_marginal`` is 'assign' when both ``confidence`` and
        ``use_marginal`` are truthy, otherwise ``confidence``. With
        ``confidence`` the sorted 0.9 max posterior region is stored as
        ``num_date_confidence``. Sets ``self.is_timetree = True``.
        """
        self.logger('estimating time tree...', 2)
        if confidence and use_marginal:
            marginal = 'assign'
        else:
            marginal = confidence
        self.tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal,
                    resolve_polytomies=resolve_polytomies, max_iter=max_iter, **kwarks)
        self.logger('estimating time tree...done', 3)
        self.dump_attr.extend(['numdate', 'date', 'sequence'])
        for node in self.tree.find_clades():
            if hasattr(node, 'attr'):
                node.attr['num_date'] = node.numdate
            else:
                node.attr = {'num_date': node.numdate}
            if confidence:
                node.attr["num_date_confidence"] = sorted(self.tt.get_max_posterior_region(node, fraction=0.9))
        self.is_timetree = True

    def save_timetree(self, fprefix, ttopts, cfopts):
        """Write ``<fprefix>_timetree.new`` (newick) and
        ``<fprefix>_timetree.pickle`` (options, per-node attributes and the
        ids of the original sequences)."""
        Phylo.write(self.tt.tree, fprefix+"_timetree.new", "newick")
        attrs = ["branch_length", "mutation_length", "clock_length", "dist2root",
                 "name", "mutations", "attr", "cseq", "sequence", "numdate"]
        n = {}
        for node in self.tt.tree.find_clades():
            n[node.name] = {x: getattr(node, x) for x in attrs}
        with open(fprefix+"_timetree.pickle", 'wb') as fh:
            pickle.dump({
                "timetree_options": ttopts,
                "clock_filter_options": cfopts,
                "nodes": n,
                "original_seqs": list(self.sequence_lookup.keys()),
            }, fh, protocol=pickle.HIGHEST_PROTOCOL)

    def restore_timetree_node_info(self, nodes):
        """Set every key/value of ``nodes[node.name]`` as an attribute on the
        corresponding node and mark the tree as a time tree."""
        for node in self.tt.tree.find_clades():
            for k, v in nodes[node.name].items():
                setattr(node, k, v)
        self.is_timetree = True

    def remove_outlier_clades(self, max_nodes=3, min_length=0.03):
        '''
        check whether one child clade of the root is small and the connecting branch
        is long. if so, move the root up and reset a few tree props
        Args:
            max_nodes   number of nodes beyond which the outliers are not removed
            min_length  minimal length of the branch connecting the outlier clade
                        to the rest of the tree to allow cutting.
        Returns:
            list of names of strains that have been removed, or None when the
            root has more than two children or nothing was removed
        '''
        R = self.tt.tree.root
        if len(R.clades) > 2:
            return

        num_child_nodes = np.array([c.count_terminals() for c in R])
        putative_outlier = num_child_nodes.argmin()
        bl = np.sum([c.branch_length for c in R])
        if (bl > min_length and num_child_nodes[putative_outlier] < max_nodes):
            if num_child_nodes[putative_outlier] == 1:
                print("removing \"{}\" which is an outlier clade".format(R.clades[putative_outlier].name))
            else:
                print("removing {} isolates which formed an outlier clade".format(num_child_nodes[putative_outlier]))
            # the sibling of the outlier clade becomes the new root
            self.tt.tree.root = R.clades[(1+putative_outlier)%2]
            self.tt.prepare_tree()
            return [c.name for c in R.clades[putative_outlier].get_terminals()]

    def geo_inference(self, attr, missing='?', root_state=None, report_confidence=False):
        '''
        infer a "mugration" model by pretending each region corresponds to a sequence
        state and repurposing the GTR inference and ancestral reconstruction

        Args:
            attr: key in ``node.attr`` holding the discrete trait.
            missing: trait value that denotes missing data.
            root_state: if given, a zero-length dummy child carrying this
                        state is attached to the root during inference.
            report_confidence: also store ``<attr>_entropy`` and
                               ``<attr>_confidence`` on every node.
        '''
        from treetime import GTR
        # Determine alphabet; nodes whose value equals `missing` must not
        # contribute a state of their own
        places = set()
        for node in self.tree.find_clades():
            if hasattr(node, 'attr'):
                if attr in node.attr and node.attr[attr] != missing:
                    places.add(node.attr[attr])
        if root_state is not None:
            places.add(root_state)

        # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45)
        places = sorted(places)
        nc = len(places)
        if nc > 180:
            self.logger("geo_inference: can't have more than 180 places!", 1)
            return
        elif nc == 1:
            self.logger("geo_inference: only one place found -- setting every internal node to %s!"%places[0], 1)
            for node in self.tree.find_clades():
                node.attr[attr] = places[0]
                setattr(node, attr+'_transitions', [])
            return
        elif nc == 0:
            self.logger("geo_inference: list of places is empty!", 1)
            return

        # store previously reconstructed sequences
        nuc_seqs = {}
        nuc_muts = {}
        nuc_seq_LH = None
        if hasattr(self.tt.tree, 'sequence_LH'):
            nuc_seq_LH = self.tt.tree.sequence_LH
        for node in self.tree.find_clades():
            if hasattr(node, 'sequence'):
                nuc_seqs[node] = node.sequence
            if hasattr(node, 'mutations'):
                nuc_muts[node] = node.mutations
                del node.mutations

        # one letter per place, starting at 'A'; the next free letter
        # stands for missing data
        alphabet = {chr(65+i): place for i, place in enumerate(places)}
        sequence_gtr = self.tt.gtr
        myGeoGTR = GTR.custom(pi=np.ones(nc, dtype=float)/nc, W=np.ones((nc, nc)),
                              alphabet=np.array(sorted(alphabet.keys())))
        missing_char = chr(65+nc)
        alphabet[missing_char] = missing
        myGeoGTR.profile_map[missing_char] = np.ones(nc)
        alphabet_rev = {v: k for k, v in alphabet.items()}

        # set geo info to nodes as one letter sequence.
        self.tt.seq_len = 1
        for node in self.tree.get_terminals():
            if hasattr(node, 'attr') and attr in node.attr:
                node.sequence = np.array([alphabet_rev[node.attr[attr]]])
            else:
                node.sequence = np.array([missing_char])
        for node in self.tree.get_nonterminals():
            # internal nodes may not carry a sequence yet
            if hasattr(node, 'sequence'):
                del node.sequence
        if root_state is not None:
            self.tree.root.split(n=1, branch_length=0.0)
            extra_clade = self.tree.root.clades[-1]
            extra_clade.name = "dummy_root_node"
            extra_clade.up = self.tree.root
            extra_clade.sequence = np.array([alphabet_rev[root_state]])
        self.tt.make_reduced_alignment()
        # set custom GTR model, run inference
        self.tt._gtr = myGeoGTR
        tmp_use_mutation_length = self.tt.use_mutation_length
        self.tt.use_mutation_length = False
        self.tt.infer_ancestral_sequences(method='ml', infer_gtr=False,
                                          store_compressed=False, pc=5.0,
                                          marginal=True, normalized_rate=False)

        if root_state is not None:
            self.tree.prune(extra_clade)
        # restore the nucleotide sequence and mutations to maintain expected behavior
        self.tt.geogtr = self.tt.gtr
        self.tt.geogtr.alphabet_to_location = alphabet
        self.tt._gtr = sequence_gtr
        if hasattr(self.tt.tree, 'sequence_LH'):
            self.tt.tree.geo_LH = self.tt.tree.sequence_LH
            self.tt.tree.sequence_LH = nuc_seq_LH
        for node in self.tree.find_clades():
            node.attr[attr] = alphabet[node.sequence[0]]
            if node in nuc_seqs:
                node.sequence = nuc_seqs[node]
            if node.up is not None:
                setattr(node, attr+'_transitions', node.mutations)
            if node in nuc_muts:
                node.mutations = nuc_muts[node]
            # save marginal likelihoods if desired
            if report_confidence:
                profile = node.marginal_profile[0]
                # entropy normalised by log(number of states)
                # javascript: vals.map((v) => v * Math.log(v + 1E-10)).reduce((a, b) => a + b, 0) * -1 / Math.log(vals.length);
                node.attr[attr + "_entropy"] = sum([v * math.log(v+1E-20) for v in profile]) * -1 / math.log(len(profile))
                marginal = [(alphabet[self.tt.geogtr.alphabet[i]], profile[i])
                            for i in range(0, len(self.tt.geogtr.alphabet))]
                marginal.sort(key=lambda x: x[1], reverse=True)  # sort on likelihoods
                marginal = [(a, b) for a, b in marginal if b > 0.01][:4]  # only take stuff over 1% and the top 4 elements
                node.attr[attr + "_confidence"] = {a: b for a, b in marginal}
        self.tt.use_mutation_length = tmp_use_mutation_length

        # store saved attrs for save/restore functionality
        if not hasattr(self, "mugration_attrs"):
            self.mugration_attrs = []
        self.mugration_attrs.append(attr)
        if report_confidence:
            self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"])

    def restore_geo_inference(self, data, attr, confidence):
        """
        Copy a previously stored trait (and, with ``confidence``, its
        ``_confidence``/``_entropy`` companions) from ``data[node.name]``
        into ``node.attr`` and record the restored keys in
        ``self.mugration_attrs``.

        Raises:
            KeyError: if ``data`` compares equal to False, or an entry is missing.
        """
        if data == False:
            raise KeyError  # yeah, not great
        for node in self.tree.find_clades():
            node.attr[attr] = data[node.name][attr]
            if confidence:
                node.attr[attr+"_confidence"] = data[node.name][attr+"_confidence"]
                node.attr[attr+"_entropy"] = data[node.name][attr+"_entropy"]
        if not hasattr(self, "mugration_attrs"):
            self.mugration_attrs = []
        self.mugration_attrs.append(attr)
        if confidence:
            self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"])

    def get_attr_list(self, get_attr):
        """Return ``node.attr[get_attr]`` for every node that has that key."""
        return [node.attr[get_attr] for node in self.tree.find_clades()
                if get_attr in node.attr]

    def add_translations(self):
        '''
        translate the nucleotide sequence into the proteins specified
        in self.proteins. these are expected to be SeqFeatures
        '''
        from Bio import Seq

        # Sort proteins by start position of the corresponding SeqFeature entry.
        sorted_proteins = sorted(self.proteins.items(), key=lambda protein_pair: protein_pair[1].start)

        for node in self.tree.find_clades(order='preorder'):
            if not hasattr(node, "translations"):
                # Maintain genomic order of protein translations for easy
                # assembly by downstream functions.
                node.translations = OrderedDict()
                node.aa_mutations = {}

            nuc_string = "".join(node.sequence)
            for prot, feature in sorted_proteins:
                node.translations[prot] = Seq.translate(str(feature.extract(Seq.Seq(nuc_string))).replace('-', 'N'))

                if node.up is None:
                    node.aa_mutations[prot] = []
                else:
                    # preorder traversal guarantees the parent is translated already
                    node.aa_mutations[prot] = [(a, pos, d) for pos, (a, d) in
                                               enumerate(zip(node.up.translations[prot],
                                                             node.translations[prot])) if a != d]

        self.dump_attr.append('translations')

    def refine(self):
        '''
        add attributes for export, currently this is only muts and aa_muts
        '''
        def describe(a, pos, d, run_length):
            # one label for a single indel or for a run of adjacent ones
            if run_length == 0:
                return a + str(pos+1) + d
            if d == '-':
                return "deletion %d-%d" % (pos-run_length, pos+1)
            return "insertion %d-%d" % (pos-run_length, pos+1)

        self.tree.ladderize()
        for node in self.tree.find_clades():
            if node.up is not None:
                node.muts = ["".join(map(str, [a, pos+1, d]))
                             for a, pos, d in node.mutations if '-' not in [a, d]]

                # Sort all deletions by position to enable identification of
                # deletions >1 bp below.
                deletions = sorted(
                    [(a, pos, d) for a, pos, d in node.mutations if '-' in [a, d]],
                    key=lambda mutation: mutation[1]
                )

                if deletions:
                    length = 0
                    for pi, (a, pos, d) in enumerate(deletions[:-1]):
                        if pos != deletions[pi+1][1]-1:
                            node.muts.append(describe(a, pos, d, length))
                            # the run ended here; start counting afresh so
                            # later indels are not reported with a stale length
                            length = 0
                        else:
                            length += 1
                    a, pos, d = deletions[-1]
                    node.muts.append(describe(a, pos, d, length))

                node.aa_muts = {}
                if hasattr(node, 'translations'):
                    for prot in node.translations:
                        node.aa_muts[prot] = ["".join(map(str, [a, pos+1, d]))
                                              for a, pos, d in node.aa_mutations[prot]]

        # cumulative divergence from the root
        for node in self.tree.find_clades(order="preorder"):
            if node.up is not None:
                node.attr["div"] = node.up.attr["div"]+node.mutation_length
            else:
                node.attr["div"] = 0

        self.dump_attr.extend(['muts', 'aa_muts', 'aa_mutations', 'mutation_length', 'mutations'])

    def layout(self):
        """Add clade, xvalue, yvalue, mutation and trunk attributes to all nodes in tree"""
        clade = 0
        yvalue = self.tree.count_terminals()
        for node in self.tree.find_clades(order="preorder"):
            node.clade = clade
            clade += 1
            if node.up is not None:
                node.xvalue = node.up.xvalue+node.mutation_length
                if self.is_timetree:
                    node.tvalue = node.numdate - self.tree.root.numdate
                else:
                    node.tvalue = 0
            else:
                node.xvalue = 0
                node.tvalue = 0
            if node.is_terminal():
                node.yvalue = yvalue
                yvalue -= 1
        # internal nodes sit at the mean height of their children
        for node in self.tree.get_nonterminals(order="postorder"):
            node.yvalue = np.mean([x.yvalue for x in node.clades])
        self.dump_attr.extend(['yvalue', 'xvalue', 'clade'])
        if self.is_timetree:
            self.dump_attr.extend(['tvalue'])

    def export(self, path='', extra_attr=None, plain_export=10, indent=None, write_seqs_json=True):
        '''
        export the tree data structure along with the sequence information as
        json files for display in web browsers.
        parameters:
            path    -- path (incl prefix) to which the output files are written.
                       filenames themselves are standardized to *tree.json and *sequences.json
            extra_attr -- attributes of tree nodes that are exported to json
                       (defaults to ['aa_muts', 'clade'])
            plain_export -- store sequences are plain strings instead of
                            differences to root if number of differences exceeds
                            len(seq)/plain_export
            indent -- passed on to write_json
            write_seqs_json -- whether the sequence json is written at all
        '''
        if extra_attr is None:
            extra_attr = ['aa_muts', 'clade']
        timetree_fname = path+'_tree.json'
        sequence_fname = path+'_sequences.json'
        tree_json = tree_to_json(self.tree.root, extra_attr=extra_attr)
        write_json(tree_json, timetree_fname, indent=indent)

        # prepare a json with sequence information to export.
        # first step: add the sequence & translations of the root as string
        elems = {}
        elems['root'] = {}
        elems['root']['nuc'] = "".join(self.tree.root.sequence)
        for prot, seq in self.tree.root.translations.items():
            elems['root'][prot] = seq

        # add sequence for every node in tree. code as difference to root
        # or as full strings.
        for node in self.tree.find_clades():
            if hasattr(node, "clade"):
                elems[node.clade] = {}
                # loop over proteins and nucleotide sequences
                for prot, seq in [('nuc', "".join(node.sequence))]+list(node.translations.items()):
                    differences = {pos: state for pos, (state, ancstate) in
                                   enumerate(zip(seq, elems['root'][prot])) if state != ancstate}
                    if plain_export*len(differences) <= len(seq):
                        elems[node.clade][prot] = differences
                    else:
                        elems[node.clade][prot] = seq
        if write_seqs_json:
            write_json(elems, sequence_fname, indent=indent)
axs[0].set_axis_off() axs[1].tick_params(labelsize=14) axs[1].set_ylabel("root-to-tip distance", fontsize=16) axs[1].set_xlabel("date", fontsize=16) fig.tight_layout() if __name__ == '__main__': # load data and parse dates plt.ion() base_name = 'data/H3N2_NA_allyears_NA.20' dates = read_dates(base_name) tt = TreeTime(gtr='Jukes-Cantor', tree=base_name + '.nwk', aln=base_name + '.fasta', verbose=1, dates=dates) # inititally the root if the tree is a mess: fig, axs = plt.subplots(1, 2, figsize=(18, 9)) axs[0].set_title("Arbitrarily rooted tree", fontsize=18) axs[1].set_title("Inverse divergence-time relationship", fontsize=18) Phylo.draw(tt.tree, show_confidence=False, axes=axs[0], label_func=lambda x: x.name.split('|')[0] if x.is_terminal() else "") tt.plot_root_to_tip(ax=axs[-1]) format_axes(fig, axs)
def refine(tree=None, aln=None, ref=None, dates=None, branch_length_inference='auto',
           confidence=False, resolve_polytomies=True, max_iter=2, precision='auto',
           infer_gtr=True, Tc=0.01, reroot=None, use_marginal=False, fixed_pi=None,
           clock_rate=None, clock_std=None, clock_filter_iqd=None, verbosity=1,
           covariance=True, **kwarks):
    """
    Build a TreeTime object from the inputs, optionally prune clock outliers,
    run the time-tree inference and return the TreeTime object.

    Args:
        tree, aln, ref, dates: passed straight to the TreeTime constructor.
            A non-None `ref` is treated as "input came from a VCF".
        branch_length_inference: forwarded as `branch_length_mode`; 'auto' is
            replaced by 'joint' when `ref` is given.
        confidence: if truthy, confidence intervals are requested and stored
            on every node as `num_date_confidence`.
        use_marginal: together with `confidence`, requests
            `time_marginal='assign'`.
        Tc: coalescent time scale; a number (or numeric string) is converted
            to float, any other value (e.g. 'opt', 'skyline', None) is
            forwarded unchanged.
        fixed_pi: equilibrium frequencies; derived from `ref` when `ref` is
            given and `fixed_pi` is None.
        clock_rate, clock_std: fixed clock rate and its standard deviation.
        clock_filter_iqd: if truthy, run the clock filter with this number of
            interquartile distances and prune the tips it marks as bad.
        covariance: forwarded as `use_covariation`; also enables rate
            variation when confidence is requested without `clock_std`.
        **kwarks: forwarded to `TreeTime.run`.

    Returns:
        The TreeTime instance after `run` has been called.
    """
    from treetime import TreeTime

    # Tc could be a number or 'opt' or 'skyline'. TreeTime expects a float or
    # int if a number. Non-numeric strings raise ValueError, None raises
    # TypeError -- in both cases the value is passed on untouched.
    try:
        Tc = float(Tc)
    except (TypeError, ValueError):
        pass

    if (ref is not None) and (fixed_pi is None):  # if VCF, fix pi
        # Otherwise mutation TO gaps is overestimated b/c of seq length
        fixed_pi = [ref.count(base) / len(ref) for base in ['A', 'C', 'G', 'T', '-']]
        if fixed_pi[-1] == 0:
            # give gaps a small weight and take it back evenly (5 * 0.01)
            # from all entries so the frequencies still sum to one
            fixed_pi[-1] = 0.05
            fixed_pi = [v - 0.01 for v in fixed_pi]

    # VCF -> adjust branch length
    # set branch length mode explicitly if auto, as informative-site only
    # trees can have big branch lengths, making this set incorrectly in TreeTime
    if ref is not None and branch_length_inference == 'auto':
        branch_length_inference = 'joint'

    # send ref, if is None, does no harm
    tt = TreeTime(tree=tree, aln=aln, ref=ref, dates=dates,
                  verbose=verbosity, gtr='JC69', precision=precision)

    # conditionally run clock-filter and remove bad tips
    if clock_filter_iqd:
        # treetime clock filter will mark, but not remove bad tips;
        # use whatever rooting was specified
        tt.clock_filter(reroot=reroot, n_iqd=clock_filter_iqd, plot=False)
        # remove them explicitly; iterate over a snapshot because pruning
        # changes the tree while we walk its tips
        for n in list(tt.tree.get_terminals()):
            if n.bad_branch:
                tt.tree.prune(n)
                print('pruning leaf ', n.name)
        # fix treetime set-up for new tree topology
        tt.prepare_tree()

    if confidence and use_marginal:
        # estimate confidence intervals via marginal ML and assign
        # marginal ML times to nodes
        marginal = 'assign'
    else:
        marginal = confidence

    # uncertainty of the clock rate is relevant if confidence intervals are estimated
    if confidence and clock_std:
        vary_rate = clock_std  # if standard deviation of clock is specified, use that
    elif confidence and covariance:
        vary_rate = True       # if run in covariance mode, standard deviation can be estimated
    else:
        vary_rate = False      # otherwise, rate uncertainty will be ignored

    tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal,
           branch_length_mode=branch_length_inference,
           resolve_polytomies=resolve_polytomies, max_iter=max_iter,
           fixed_pi=fixed_pi, fixed_clock_rate=clock_rate,
           vary_rate=vary_rate, use_covariation=covariance, **kwarks)

    if confidence:
        for n in tt.tree.find_clades():
            n.num_date_confidence = list(tt.get_max_posterior_region(n, 0.9))

    print("\nInferred a time resolved phylogeny using TreeTime:"
          "\n\tSagulenko et al. TreeTime: Maximum-likelihood phylodynamic analysis"
          "\n\tVirus Evolution, vol 4, https://academic.oup.com/ve/article/4/1/vex042/4794731\n")
    return tt
def timetree(params):
    """
    Command-line driver for a TreeTime time-tree inference.

    Reads dates, tree and alignment (or VCF) as described by `params`, runs
    TreeTime, and writes the inferred models, plots, optional branch-specific
    rates and the annotated tree into the output directory.

    Args:
        params: namespace of parsed command-line arguments (attributes such
            as `dates`, `tree`, `aln`, `relax`, `coalescent`, `confidence`
            are read below).

    Returns:
        0 on success, 1 if the set-up or the TreeTime run failed.
    """
    # translate the '--relax' argument into what TreeTime.run expects
    if params.relax is None:
        relaxed_clock_params = None
    elif params.relax == []:
        relaxed_clock_params = True
    elif len(params.relax) == 2:
        relaxed_clock_params = {'slack': params.relax[0], 'coupling': params.relax[1]}
    else:
        # previously this fell through and crashed later with an unbound
        # local variable; fail early with an explicit message instead
        print("'--relax' expects either no values or exactly two values "
              "(slack and coupling) -- exiting.", file=sys.stderr)
        return 1

    dates = utils.parse_dates(params.dates)
    if len(dates) == 0:
        print("No valid dates -- exiting.")
        return 1

    if assure_tree(params, tmp_dir='timetree_tmp'):
        print("No tree -- exiting.")
        return 1

    outdir = get_outdir(params, '_treetime')

    gtr = create_gtr(params)
    infer_gtr = params.gtr == 'infer'

    ###########################################################################
    ### READ IN VCF
    ###########################################################################
    # sets ref and fixed_pi to None if not VCF
    aln, ref, fixed_pi = read_if_vcf(params)
    is_vcf = ref is not None
    branch_length_mode = params.branch_length_mode
    # variable-site-only trees can have big branch lengths, the auto setting won't work.
    if is_vcf or (params.aln and params.sequence_length):
        if branch_length_mode == 'auto':
            branch_length_mode = 'joint'

    ###########################################################################
    ### SET-UP and RUN
    ###########################################################################
    if params.aln is None and params.sequence_length is None:
        print("one of arguments '--aln' and '--sequence-length' is required.", file=sys.stderr)
        return 1

    myTree = TreeTime(dates=dates, tree=params.tree, ref=ref,
                      aln=aln, gtr=gtr, seq_len=params.sequence_length,
                      verbose=params.verbose)
    myTree.tip_slack = params.tip_slack
    if not myTree.one_mutation:
        print("TreeTime setup failed, exiting")
        return 1

    # coalescent model options: either a number or one of the named models
    try:
        coalescent = float(params.coalescent)
        # values below ten mutations are treated as "no coalescent model"
        if coalescent < 10 * myTree.one_mutation:
            coalescent = None
    except (TypeError, ValueError):
        if params.coalescent in ['opt', 'const', 'skyline']:
            coalescent = params.coalescent
        else:
            print("unknown coalescent model specification, has to be either "
                  "a float, 'opt', 'const' or 'skyline' -- exiting")
            return 1

    # determine whether confidence intervals are to be computed and how the
    # uncertainty in the rate estimate should be treated
    calc_confidence = params.confidence
    if params.clock_std_dev:
        vary_rate = params.clock_std_dev if calc_confidence else False
    elif params.confidence and params.covariation:
        vary_rate = True
    elif params.confidence:
        print("\nOutside of covariation aware mode TreeTime cannot estimate confidence intervals "
              "without specified standard deviation of the clock rate. \nPlease specify '--clock-std-dev' "
              "or rerun with '--covariation'. \nWill proceed without confidence estimation")
        vary_rate = False
        calc_confidence = False
    else:
        vary_rate = False

    # RUN
    root = None if params.keep_root else params.reroot
    success = myTree.run(root=root, relaxed_clock=relaxed_clock_params,
                         resolve_polytomies=(not params.keep_polytomies),
                         Tc=coalescent, max_iter=params.max_iter,
                         fixed_clock_rate=params.clock_rate,
                         n_iqd=params.clock_filter,
                         time_marginal="assign" if calc_confidence else False,
                         vary_rate=vary_rate,
                         branch_length_mode=branch_length_mode,
                         fixed_pi=fixed_pi,
                         use_covariation=params.covariation)
    if success == ttconf.ERROR:  # if TreeTime.run failed, exit
        print("\nTreeTime run FAILED: please check above for errors and/or rerun with --verbose 4.\n")
        return 1

    ###########################################################################
    ### OUTPUT and saving of results
    ###########################################################################
    if infer_gtr:
        fname = outdir + 'sequence_evolution_model.txt'
        with open(fname, 'w') as ofile:
            ofile.write(str(myTree.gtr) + '\n')
        print('\nInferred sequence evolution model (saved as %s):' % fname)
        print(myTree.gtr)

    fname = outdir + 'molecular_clock.txt'
    with open(fname, 'w') as ofile:
        ofile.write(str(myTree.date2dist) + '\n')
    # this block reports the clock model, not the sequence evolution model
    print('\nInferred molecular clock model (saved as %s):' % fname)
    print(myTree.date2dist)

    basename = get_basename(params, outdir)
    if coalescent in ['skyline', 'opt', 'const']:
        print("Inferred coalescent model")
        if coalescent == 'skyline':
            print_save_plot_skyline(myTree, plot=basename + 'skyline.pdf',
                                    save=basename + 'skyline.tsv', screen=True)
        else:
            Tc = myTree.merger_model.Tc.y[0]
            print(" --T_c: \t %1.2e \toptimized inverse merger rate in units of substitutions" % Tc)
            print(" --T_c: \t %1.2e \toptimized inverse merger rate in years" % (Tc / myTree.date2dist.clock_rate))
            print(" --N_e: \t %1.2e \tcorresponding 'effective population size' assuming 50 gen/year\n" % (Tc / myTree.date2dist.clock_rate * 50))

    # plot
    import matplotlib.pyplot as plt
    from .treetime import plot_vs_years
    leaf_count = myTree.tree.count_terminals()

    def label_func(x):
        # label tips only: always when requested, otherwise for small trees
        # unless labels were switched off
        show = x.is_terminal() and ((leaf_count < 30 and (not params.no_tip_labels))
                                    or params.tip_labels)
        return x.name if show else ''

    # use calc_confidence: confidence may have been switched off above, in
    # which case no intervals were computed for plotting
    plot_vs_years(myTree, show_confidence=False, label_func=label_func,
                  confidence=0.9 if calc_confidence else None)
    tree_fname = (outdir + params.plot_tree)
    plt.savefig(tree_fname)
    print("--- saved tree as \n\t %s\n" % tree_fname)

    plot_rtt(myTree, outdir + params.plot_rtt)
    if params.relax:
        fname = outdir + 'substitution_rates.tsv'
        print("--- wrote branch specific rates to\n\t %s\n" % fname)
        with open(fname, 'w') as fh:
            fh.write("#node\tclock_length\tmutation_length\trate\tfold_change\n")
            for n in myTree.tree.find_clades(order="preorder"):
                if n == myTree.tree.root:
                    continue
                g = n.branch_length_interpolator.gamma
                fh.write("%s\t%1.3e\t%1.3e\t%1.3e\t%1.2f\n" % (
                    n.name, n.clock_length, n.mutation_length,
                    myTree.date2dist.clock_rate * g, g))

    export_sequences_and_tree(myTree, basename, is_vcf, params.zero_based,
                              timetree=True, confidence=calc_confidence)

    return 0
except: print("Seaborn not found. Default style will be used for the plots") from treetime import TreeTime from treetime.utils import parse_dates import treetime.config as ttconf if __name__ == '__main__': plt.ion() base_name = '../data/ebola/ebola' dates = parse_dates(base_name + '.metadata.csv') # instantiate treetime ebola = TreeTime(gtr='Jukes-Cantor', tree=base_name + '.nwk', precision=1, aln=base_name + '.fasta', verbose=2, dates=dates) # infer an ebola time tree while rerooting and resolving polytomies res = ebola.run(root='best', infer_gtr=True, relaxed_clock=False, max_iter=2, branch_length_mode='input', n_iqd=3, resolve_polytomies=True, Tc='skyline', time_marginal="assign") if res == ttconf.ERROR:
def timetree(params):
    """
    Command-line driver for a TreeTime time-tree inference.

    Reads dates, tree and alignment (or VCF) as described by `params`, runs
    TreeTime, prints the inferred models, saves the tree plot and exports
    the annotated tree.

    Args:
        params: namespace of parsed command-line arguments. `params.relax`
            is replaced by True when it is an empty list.

    Returns:
        0 on success, 1 if the set-up or the TreeTime run failed.
    """
    if params.relax == []:
        params.relax = True

    dates = utils.parse_dates(params.dates)
    if len(dates) == 0:
        return 1

    if assure_tree(params, tmp_dir='timetree_tmp'):
        return 1

    outdir = get_outdir(params, '_treetime')

    gtr = create_gtr(params)
    infer_gtr = params.gtr == 'infer'

    ###########################################################################
    ### READ IN VCF
    ###########################################################################
    # sets ref and fixed_pi to None if not VCF
    aln, ref, fixed_pi = read_if_vcf(params)
    is_vcf = ref is not None
    branch_length_mode = params.branch_length_mode
    # variable-site-only trees can have big branch lengths, setting this wrong.
    if is_vcf and branch_length_mode == 'auto':
        branch_length_mode = 'joint'

    ###########################################################################
    ### SET-UP and RUN
    ###########################################################################
    if params.aln is None and params.sequence_length is None:
        print("one of arguments '--aln' and '--sequence-length' is required.", file=sys.stderr)
        return 1

    myTree = TreeTime(dates=dates, tree=params.tree, ref=ref,
                      aln=aln, gtr=gtr, seq_len=params.sequence_length,
                      verbose=params.verbose)

    # coalescent model options: either a number or one of the named models
    try:
        coalescent = float(params.coalescent)
        # values below ten mutations are treated as "no coalescent model"
        if coalescent < 10 * myTree.one_mutation:
            coalescent = None
    except (TypeError, ValueError):
        if params.coalescent in ['opt', 'const', 'skyline']:
            coalescent = params.coalescent
        else:
            print("unknown coalescent model specification, has to be either "
                  "a float, 'opt', 'const' or 'skyline'")
            coalescent = None

    # rate variation follows the confidence flag unless both a clock rate and
    # its standard deviation were given explicitly
    vary_rate = params.confidence
    if params.clock_std_dev and params.clock_rate:
        vary_rate = params.clock_std_dev

    root = None if params.keep_root else params.reroot
    success = myTree.run(root=root, relaxed_clock=params.relax,
                         resolve_polytomies=(not params.keep_polytomies),
                         Tc=coalescent, max_iter=params.max_iter,
                         fixed_clock_rate=params.clock_rate,
                         n_iqd=params.clock_filter,
                         time_marginal="assign" if params.confidence else False,
                         vary_rate=vary_rate,
                         branch_length_mode=branch_length_mode,
                         fixed_pi=fixed_pi)
    if success == ttconf.ERROR:  # if TreeTime.run failed, exit
        return 1

    ###########################################################################
    ### OUTPUT and saving of results
    ###########################################################################
    if infer_gtr:
        print('\nInferred GTR model:')
        print(myTree.gtr)

    print(myTree.date2dist)

    basename = get_basename(params, outdir)
    if coalescent in ['skyline', 'opt']:
        print("Inferred coalescent model")
        if coalescent == 'skyline':
            print_save_plot_skyline(myTree, plot=basename + 'skyline.pdf',
                                    save=basename + 'skyline.tsv', screen=True)
        elif coalescent == 'opt':
            Tc = myTree.merger_model.Tc.y[0]
            print(" --T_c: \t %1.4f \toptimized inverse merger rate" % Tc)
            print(" --N_e: \t %1.1f \tcorresponding pop size assument 50 gen/year\n" % (Tc / myTree.date2dist.clock_rate * 50))

    # plot
    import matplotlib.pyplot as plt
    from .treetime import plot_vs_years
    leaf_count = myTree.tree.count_terminals()

    def label_func(x):
        # The original used `leaf_count < 30 & x.is_terminal()`; `&` binds
        # tighter than `<`, so that compared leaf_count with `30 & bool`
        # (always 0) and no tip was ever labelled. Use a logical `and`.
        return x.name[:20] if (leaf_count < 30 and x.is_terminal()) else ''

    plot_vs_years(myTree, show_confidence=False, label_func=label_func,
                  confidence=0.9 if params.confidence else None)
    tree_fname = (outdir + params.plot_tree)
    plt.savefig(tree_fname)
    print("--- saved tree as \n\t %s\n" % tree_fname)
    export_sequences_and_tree(myTree, basename, is_vcf, params.zero_based,
                              timetree=True, confidence=params.confidence)

    plot_rtt(myTree, outdir + params.plot_rtt)
    return 0
class tree(object):
    """tree builds a phylogenetic tree from an alignment and exports it for web visualization"""

    def __init__(self, aln, proteins=None, verbose=2, logger=None, **kwarks):
        """
        Args:
            aln: iterable of sequence records; each record needs an `id`.
            proteins: mapping of protein name to feature, used by
                `add_translations`; defaults to an empty dict.
            verbose: verbosity threshold of the default logger.
            logger: callable `(message, level)`; if None, a logger that
                prints messages whose level is below `verbose` is used.
            **kwarks: optional `nuc` (bool, default True) and `run_dir`
                (default: a unique temporary directory name).
        """
        super(tree, self).__init__()
        self.aln = aln
        # self.nthreads = 2
        self.sequence_lookup = {seq.id: seq for seq in aln}
        self.nuc = kwarks['nuc'] if 'nuc' in kwarks else True
        self.dump_attr = []  # deprecated
        self.verbose = verbose
        self.proteins = proteins if proteins is not None else {}
        if 'run_dir' not in kwarks:
            import random
            self.run_dir = '_'.join(['temp',
                                     time.strftime('%Y%m%d-%H%M%S', time.gmtime()),
                                     str(random.randint(0, 1000000))])
        else:
            self.run_dir = kwarks['run_dir']
        if logger is None:
            def f(x, y):
                if y < self.verbose:
                    print(x)
            self.logger = f
        else:
            self.logger = logger

    def getDateMin(self):
        """Return the `date` of the root node."""
        return self.tree.root.date

    def getDateMax(self):
        """Return the largest `date` found on any node of the tree."""
        dateMax = self.tree.root.date
        for node in self.tree.find_clades():
            if node.date > dateMax:
                dateMax = node.date
        return dateMax

    def dump(self, treefile, nodefile):
        """Write the tree as newick to `treefile` and pickle the node
        attributes listed in `self.dump_attr` to `nodefile`."""
        from Bio import Phylo
        Phylo.write(self.tree, treefile, 'newick')
        node_props = {}
        for node in self.tree.find_clades():
            node_props[node.name] = {attr: getattr(node, attr)
                                     for attr in self.dump_attr
                                     if hasattr(node, attr)}

        with myopen(nodefile, 'w') as nfile:
            pickle.dump(node_props, nfile)

    def check_newick(self, newick_file):
        """Return True if `newick_file` can be parsed and its tip names are
        exactly the ids of the alignment, False otherwise."""
        try:
            # builtin next() works on Python 2 and 3; the `.next()` method
            # exists on Python 2 only and made this always return False
            parsed = next(Phylo.parse(newick_file, 'newick'))
            tip_names = set(x.name for x in parsed.get_terminals())
        except Exception:
            # unreadable or empty file: treat as "no usable tree"
            return False
        return tip_names == set(self.sequence_lookup.keys())

    def build_newick(self, newick_file, nthreads=2, method="raxml",
                     raxml_options=None, iqtree_options=None, debug=False):
        """
        Build a tree from the alignment with an external program and save it
        to `newick_file` (relative to the directory the call was made from).

        Args:
            newick_file: output file name.
            nthreads: number of threads (used by the raxml builder).
            method: "raxml", "fasttree" or "iqtree".
            raxml_options / iqtree_options: extra keyword arguments for the
                respective builder.
            debug: if True, the temporary run directory is kept.
        """
        raxml_options = {} if raxml_options is None else raxml_options
        iqtree_options = {} if iqtree_options is None else iqtree_options

        make_dir(self.run_dir)
        os.chdir(self.run_dir)
        try:
            for seq in self.aln:
                seq.name = seq.id

            out_fname = os.path.join("..", newick_file)
            if method == "raxml":
                self.build_newick_raxml(out_fname, nthreads=nthreads, **raxml_options)
            elif method == "fasttree":
                self.build_newick_fasttree(out_fname)
            elif method == "iqtree":
                self.build_newick_iqtree(out_fname, **iqtree_options)
        finally:
            # always leave the run directory, also when a builder raised
            os.chdir('..')
        self.logger("Saved new tree to %s" % out_fname, 1)
        if not debug:
            remove_dir(self.run_dir)

    def build_newick_fasttree(self, out_fname):
        """Run `fasttree` on the alignment and write the tree to `out_fname`."""
        from Bio import AlignIO
        AlignIO.write(self.aln, 'temp.fasta', 'fasta')
        self.logger("Building tree with fasttree", 1)
        tree_cmd = ["fasttree"]
        if self.nuc:
            tree_cmd.append("-nt")

        tree_cmd.extend(["temp.fasta", "1>", out_fname, "2>", "fasttree_stderr"])
        # NOTE(review): the command is run through the shell so that the
        # redirections work; out_fname is interpolated unescaped.
        os.system(" ".join(tree_cmd))

    def build_newick_raxml(self, out_fname, nthreads=2, raxml_bin="raxml",
                           num_distinct_starting_trees=1, **kwargs):
        """
        Run RAxML on the alignment and copy the best tree to `out_fname`.

        Raises:
            CalledProcessError: re-raised if RAxML exits with an error.
        """
        from Bio import AlignIO
        import shutil
        self.logger("modified RAxML script - no branch length optimisation or time limit", 1)
        AlignIO.write(self.aln, "temp.phyx", "phylip-relaxed")
        if num_distinct_starting_trees == 1:
            cmd = raxml_bin + " -f d -T " + str(nthreads) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx"
        else:
            self.logger("RAxML running with {} starting trees (longer but better...)".format(num_distinct_starting_trees), 1)
            cmd = raxml_bin + " -f d -T " + str(nthreads) + " -N " + str(num_distinct_starting_trees) + " -m GTRCAT -c 25 -p 235813 -n tre -s temp.phyx"

        try:
            with open("raxml.log", 'w') as fh:
                check_call(cmd, stdout=fh, stderr=STDOUT, shell=True)
                self.logger("RAXML COMPLETED.", 1)
        except CalledProcessError:
            self.logger("RAXML TREE FAILED - check {}/raxml.log".format(self.run_dir), 1)
            raise
        shutil.copy('RAxML_bestTree.tre', out_fname)

    def build_newick_iqtree(self, out_fname, nthreads=2, iqtree_bin="iqtree",
                            iqmodel="HKY", **kwargs):
        """
        Run IQ-TREE on the alignment and write the tree to `out_fname`.

        Args:
            nthreads: value passed to `-nt`.
            iqtree_bin: name or path of the iqtree executable.
            iqmodel: substitution model passed to `-m` (together with
                `-fast`); if falsy, neither option is passed.
        """
        from Bio import Phylo, AlignIO
        self.logger("Building tree with iqtree", 1)
        aln_file = "temp.fasta"
        AlignIO.write(self.aln, aln_file, "fasta")
        # '/' in sequence names is masked before the run and restored in the
        # resulting tree below
        with open(aln_file) as ifile:
            tmp_seqs = ifile.readlines()
        with open(aln_file, 'w') as ofile:
            for line in tmp_seqs:
                ofile.write(line.replace('/', '_X_X_'))

        # honour iqtree_bin (previously the executable name was hard-coded)
        call = [iqtree_bin, "-nt", str(nthreads), "-s", aln_file]
        if iqmodel:
            call.extend(["-m", iqmodel, "-fast"])
        call.extend([">", "iqtree.log"])

        os.system(" ".join(call))
        T = Phylo.read(aln_file + ".treefile", 'newick')
        for n in T.get_terminals():
            n.name = n.name.replace('_X_X_', '/')
        Phylo.write(T, out_fname, 'newick')

    def tt_from_file(self, infile, root='best', nodefile=None):
        """
        Load a tree from `infile` into a TreeTime object (`self.tt`), attach
        sequence attributes to the tips and optionally restore pickled node
        properties from `nodefile`.

        Args:
            infile: path of the tree file.
            root: rerooting mode passed to `TreeTime.reroot`; falsy skips it.
            nodefile: pickle written by `dump`, or None.
        """
        self.is_timetree = False
        self.logger('Reading tree from file ' + infile, 2)
        dates = {seq.id: seq.attributes['num_date']
                 for seq in self.aln if 'date' in seq.attributes}
        self.tt = TreeTime(dates=dates, tree=str(infile), gtr='Jukes-Cantor',
                           aln=self.aln, verbose=self.verbose, fill_overhangs=True)
        if root:
            self.tt.reroot(root=root)
        self.tree = self.tt.tree

        for node in self.tree.find_clades():
            if node.is_terminal() and node.name in self.sequence_lookup:
                seq = self.sequence_lookup[node.name]
                node.attr = seq.attributes
                # best effort: turn date objects into strings, leave
                # anything that cannot be formatted untouched
                try:
                    node.attr['date'] = node.attr['date'].strftime('%Y-%m-%d')
                except Exception:
                    pass
            else:
                node.attr = {}

        if nodefile is not None:
            self.logger('reading node properties from file: ' + nodefile, 2)
            # NOTE(review): pickle.load executes arbitrary code from the
            # file; only use node files from a trusted source.
            with myopen(nodefile, 'r') as nfile:
                node_props = pickle.load(nfile)
            for n in self.tree.find_clades():
                if n.name in node_props:
                    for attr in node_props[n.name]:
                        setattr(n, attr, node_props[n.name][attr])
                else:
                    self.logger("No node properties found for " + n.name, 2)

    def ancestral(self, **kwarks):
        """Infer ancestral sequences and branch lengths with TreeTime and
        make sure every node has an `attr` dict."""
        self.tt.optimize_seq_and_branch_len(infer_gtr=True, **kwarks)
        self.dump_attr.append('sequence')
        for node in self.tree.find_clades():
            if not hasattr(node, 'attr'):
                node.attr = {}

    ## TODO REMOVE KWARKS - MAKE EXPLICIT
    def timetree(self, Tc=0.01, infer_gtr=True, reroot='best', resolve_polytomies=True,
                 max_iter=2, confidence=False, use_marginal=False, **kwarks):
        """
        Run the TreeTime inference and store `num_date` (and, if requested,
        `num_date_confidence`) in every node's `attr` dict.

        Args:
            confidence: if truthy, 90% posterior regions are stored.
            use_marginal: together with `confidence`, requests
                `time_marginal='assign'`.
            other arguments and **kwarks are forwarded to `TreeTime.run`.
        """
        self.logger('estimating time tree...', 2)
        if confidence and use_marginal:
            marginal = 'assign'
        else:
            marginal = confidence
        self.tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc, time_marginal=marginal,
                    resolve_polytomies=resolve_polytomies, max_iter=max_iter, **kwarks)
        self.logger('estimating time tree...done', 3)
        self.dump_attr.extend(['numdate', 'date', 'sequence'])
        for node in self.tree.find_clades():
            if hasattr(node, 'attr'):
                node.attr['num_date'] = node.numdate
            else:
                node.attr = {'num_date': node.numdate}
            if confidence:
                node.attr["num_date_confidence"] = sorted(
                    self.tt.get_max_posterior_region(node, fraction=0.9))

        self.is_timetree = True

    def save_timetree(self, fprefix, ttopts, cfopts):
        """Write the time tree as newick and pickle node attributes plus the
        options used (`ttopts`, `cfopts`) next to it, both prefixed by
        `fprefix`."""
        Phylo.write(self.tt.tree, fprefix + "_timetree.new", "newick")
        n = {}
        attrs = ["branch_length", "mutation_length", "clock_length", "dist2root",
                 "name", "mutations", "attr", "cseq", "sequence", "numdate"]
        for node in self.tt.tree.find_clades():
            n[node.name] = {x: getattr(node, x) for x in attrs}
        with open(fprefix + "_timetree.pickle", 'wb') as fh:
            pickle.dump({
                "timetree_options": ttopts,
                "clock_filter_options": cfopts,
                "nodes": n,
                "original_seqs": list(self.sequence_lookup.keys()),
            }, fh, protocol=pickle.HIGHEST_PROTOCOL)

    def restore_timetree_node_info(self, nodes):
        """Set the attributes stored in `nodes` (name -> {attr: value}) on
        the matching tree nodes and mark the tree as a time tree."""
        for node in self.tt.tree.find_clades():
            info = nodes[node.name]
            # print("restoring node info for node ", node.name)
            for k, v in info.items():
                setattr(node, k, v)
        self.is_timetree = True

    def remove_outlier_clades(self, max_nodes=3, min_length=0.03):
        '''
        check whether one child clade of the root is small and the connecting branch
        is long. if so, move the root up and reset a few tree props
        Args:
            max_nodes   number of nodes beyond which the outliers are not removed
            min_length  minimal length of the branch connecting the outlier clade
                        to the rest of the tree to allow cutting.
        Returns:
            list of names of strains that have been removed (empty if the
            criteria were not met); None if the root is not bifurcating
        '''
        R = self.tt.tree.root
        # only a bifurcating root has a well defined "other" child to move
        # the root to; fewer than two children would fail below
        if len(R.clades) != 2:
            return

        num_child_nodes = np.array([c.count_terminals() for c in R])
        putative_outlier = num_child_nodes.argmin()
        bl = np.sum([c.branch_length for c in R])
        if not (bl > min_length and num_child_nodes[putative_outlier] < max_nodes):
            # nothing was cut, so no strain may be reported as removed
            return []

        if num_child_nodes[putative_outlier] == 1:
            print("removing \"{}\" which is an outlier clade".format(R.clades[putative_outlier].name))
        else:
            print("removing {} isolates which formed an outlier clade".format(num_child_nodes[putative_outlier]))
        removed = [c.name for c in R.clades[putative_outlier].get_terminals()]
        self.tt.tree.root = R.clades[(1 + putative_outlier) % 2]
        self.tt.prepare_tree()
        return removed

    def geo_inference(self, attr, missing='?', root_state=None, report_confidence=False):
        '''
        infer a "mugration" model by pretending each region corresponds to a sequence
        state and repurposing the GTR inference and ancestral reconstruction

        Args:
            attr: key in `node.attr` holding the discrete trait.
            missing: value that marks an unknown trait; it is not treated
                as a state of its own.
            root_state: if given, the root is constrained to this state via a
                temporary zero-length dummy tip.
            report_confidence: if True, store `<attr>_entropy` and
                `<attr>_confidence` in `node.attr`.
        '''
        from treetime import GTR
        # Determine alphabet
        places = set()
        for node in self.tree.find_clades():
            if hasattr(node, 'attr'):
                # compare the node's VALUE with the missing symbol (the
                # original compared the attribute name, so missing values
                # became a regular state)
                if attr in node.attr and node.attr[attr] != missing:
                    places.add(node.attr[attr])
        if root_state is not None:
            places.add(root_state)

        # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45)
        places = sorted(places)
        nc = len(places)
        if nc > 180:
            self.logger("geo_inference: can't have more than 180 places!", 1)
            return
        elif nc == 1:
            self.logger("geo_inference: only one place found -- setting every internal node to %s!" % places[0], 1)
            for node in self.tree.find_clades():
                node.attr[attr] = places[0]
                setattr(node, attr + '_transitions', [])
            return
        elif nc == 0:
            self.logger("geo_inference: list of places is empty!", 1)
            return

        # store previously reconstructed sequences
        nuc_seqs = {}
        nuc_muts = {}
        nuc_seq_LH = None
        if hasattr(self.tt.tree, 'sequence_LH'):
            nuc_seq_LH = self.tt.tree.sequence_LH
        for node in self.tree.find_clades():
            if hasattr(node, 'sequence'):
                nuc_seqs[node] = node.sequence
            if hasattr(node, 'mutations'):
                nuc_muts[node] = node.mutations
                delattr(node, 'mutations')

        # one letter per place, plus one extra letter for "missing"
        alphabet = {chr(65 + i): place for i, place in enumerate(places)}
        sequence_gtr = self.tt.gtr
        myGeoGTR = GTR.custom(pi=np.ones(nc, dtype=float) / nc, W=np.ones((nc, nc)),
                              alphabet=np.array(sorted(alphabet.keys())))
        missing_char = chr(65 + nc)
        alphabet[missing_char] = missing
        myGeoGTR.profile_map[missing_char] = np.ones(nc)
        alphabet_rev = {v: k for k, v in alphabet.items()}

        # set geo info to nodes as one letter sequence.
        self.tt.seq_len = 1
        for node in self.tree.get_terminals():
            if hasattr(node, 'attr') and attr in node.attr:
                node.sequence = np.array([alphabet_rev[node.attr[attr]]])
            else:
                node.sequence = np.array([missing_char])
        for node in self.tree.get_nonterminals():
            if hasattr(node, 'sequence'):
                delattr(node, 'sequence')
        if root_state is not None:
            self.tree.root.split(n=1, branch_length=0.0)
            extra_clade = self.tree.root.clades[-1]
            extra_clade.name = "dummy_root_node"
            extra_clade.up = self.tree.root
            extra_clade.sequence = np.array([alphabet_rev[root_state]])
        self.tt.make_reduced_alignment()
        # set custom GTR model, run inference
        self.tt._gtr = myGeoGTR
        # import pdb; pdb.set_trace()
        tmp_use_mutation_length = self.tt.use_mutation_length
        self.tt.use_mutation_length = False
        self.tt.infer_ancestral_sequences(method='ml', infer_gtr=False,
                                          store_compressed=False, pc=5.0,
                                          marginal=True, normalized_rate=False)

        if root_state is not None:
            self.tree.prune(extra_clade)

        # restore the nucleotide sequence and mutations to maintain expected behavior
        self.tt.geogtr = self.tt.gtr
        self.tt.geogtr.alphabet_to_location = alphabet
        self.tt._gtr = sequence_gtr
        if hasattr(self.tt.tree, 'sequence_LH'):
            self.tt.tree.geo_LH = self.tt.tree.sequence_LH
            self.tt.tree.sequence_LH = nuc_seq_LH
        for node in self.tree.find_clades():
            node.attr[attr] = alphabet[node.sequence[0]]
            if node in nuc_seqs:
                node.sequence = nuc_seqs[node]
            if node.up is not None:
                setattr(node, attr + '_transitions', node.mutations)
            if node in nuc_muts:
                node.mutations = nuc_muts[node]
            # save marginal likelihoods if desired
            if report_confidence:
                profile = node.marginal_profile[0]
                # entropy normalised by log(number of states)
                node.attr[attr + "_entropy"] = sum([v * math.log(v + 1E-20) for v in profile]) * -1 / math.log(len(profile))
                # javascript: vals.map((v) => v * Math.log(v + 1E-10)).reduce((a, b) => a + b, 0) * -1 / Math.log(vals.length);
                marginal = [(alphabet[self.tt.geogtr.alphabet[i]], profile[i])
                            for i in range(0, len(self.tt.geogtr.alphabet))]
                marginal.sort(key=lambda x: x[1], reverse=True)  # sort on likelihoods
                # only take stuff over 1% and the top 4 elements
                marginal = [(a, b) for a, b in marginal if b > 0.01][:4]
                node.attr[attr + "_confidence"] = {a: b for a, b in marginal}
        self.tt.use_mutation_length = tmp_use_mutation_length

        # store saved attrs for save/restore functionality
        if not hasattr(self, "mugration_attrs"):
            self.mugration_attrs = []
        self.mugration_attrs.append(attr)
        if report_confidence:
            self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"])

    def restore_geo_inference(self, data, attr, confidence):
        """
        Copy a previously saved trait (and optionally its confidence and
        entropy) from `data` (node name -> dict) onto the tree nodes.

        Raises:
            KeyError: if `data` equals False or lacks a node / key.
        """
        if data == False:
            raise KeyError  # yeah, not great
        for node in self.tree.find_clades():
            node.attr[attr] = data[node.name][attr]
            if confidence:
                node.attr[attr + "_confidence"] = data[node.name][attr + "_confidence"]
                node.attr[attr + "_entropy"] = data[node.name][attr + "_entropy"]
        if not hasattr(self, "mugration_attrs"):
            self.mugration_attrs = []
        self.mugration_attrs.append(attr)
        if confidence:
            self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"])

    def get_attr_list(self, get_attr):
        """Return the values of `node.attr[get_attr]` for all nodes that
        have this key, in traversal order."""
        return [node.attr[get_attr] for node in self.tree.find_clades()
                if get_attr in node.attr]

    def add_translations(self):
        '''
        translate the nucleotide sequence into the proteins specified
        in self.proteins. these are expected to be SeqFeatures
        '''
        from Bio import Seq

        # Sort proteins by start position of the corresponding SeqFeature entry.
        sorted_proteins = sorted(self.proteins.items(),
                                 key=lambda protein_pair: protein_pair[1].start)

        # preorder: a parent's translations exist before its children
        # are compared against them
        for node in self.tree.find_clades(order='preorder'):
            if not hasattr(node, "translations"):
                # Maintain genomic order of protein translations for easy
                # assembly by downstream functions.
                node.translations = OrderedDict()
                node.aa_mutations = {}

            for prot, feature in sorted_proteins:
                node.translations[prot] = Seq.translate(
                    str(feature.extract(Seq.Seq("".join(node.sequence)))).replace('-', 'N'))

                if node.up is None:
                    node.aa_mutations[prot] = []
                else:
                    node.aa_mutations[prot] = [
                        (a, pos, d) for pos, (a, d) in
                        enumerate(zip(node.up.translations[prot],
                                      node.translations[prot])) if a != d]

        self.dump_attr.append('translations')

    @staticmethod
    def _summarize_indels(indels):
        """Turn position-sorted gap mutations `(a, pos, d)` into strings:
        single events as e.g. 'A6-', runs of adjacent positions as
        'deletion start-end' / 'insertion start-end'."""
        summary = []
        length = 0  # number of adjacent positions joined into the current run
        last = len(indels) - 1
        for pi, (a, pos, d) in enumerate(indels):
            if pi < last and pos == indels[pi + 1][1] - 1:
                length += 1
                continue
            if length == 0:
                summary.append(a + str(pos + 1) + d)
            elif d == '-':
                summary.append("deletion %d-%d" % (pos - length, pos + 1))
            else:
                summary.append("insertion %d-%d" % (pos - length, pos + 1))
            # start a fresh run; without this reset every later event was
            # reported with the length of an earlier run
            length = 0
        return summary

    def refine(self):
        '''
        add attributes for export, currently this is only muts and aa_muts
        (plus the cumulative divergence `div` in `node.attr`)
        '''
        self.tree.ladderize()
        for node in self.tree.find_clades():
            if node.up is not None:
                node.muts = ["".join(map(str, [a, pos + 1, d]))
                             for a, pos, d in node.mutations if '-' not in [a, d]]

                # Sort all deletions by position to enable identification of
                # deletions >1 bp below.
                deletions = sorted(
                    [(a, pos, d) for a, pos, d in node.mutations if '-' in [a, d]],
                    key=lambda mutation: mutation[1])
                node.muts.extend(self._summarize_indels(deletions))

                node.aa_muts = {}
                if hasattr(node, 'translations'):
                    for prot in node.translations:
                        node.aa_muts[prot] = ["".join(map(str, [a, pos + 1, d]))
                                              for a, pos, d in node.aa_mutations[prot]]

        # preorder so that the parent's divergence is known first
        for node in self.tree.find_clades(order="preorder"):
            if node.up is not None:
                node.attr["div"] = node.up.attr["div"] + node.mutation_length
            else:
                node.attr["div"] = 0
        self.dump_attr.extend(['muts', 'aa_muts', 'aa_mutations', 'mutation_length', 'mutations'])

    def layout(self):
        """Add clade, xvalue, yvalue, mutation and trunk attributes to all nodes in tree"""
        clade = 0
        yvalue = self.tree.count_terminals()
        for node in self.tree.find_clades(order="preorder"):
            node.clade = clade
            clade += 1
            if node.up is not None:
                node.xvalue = node.up.xvalue + node.mutation_length
                if self.is_timetree:
                    node.tvalue = node.numdate - self.tree.root.numdate
                else:
                    node.tvalue = 0
            else:
                node.xvalue = 0
                node.tvalue = 0
            if node.is_terminal():
                node.yvalue = yvalue
                yvalue -= 1
        # internal nodes sit at the mean height of their children
        for node in self.tree.get_nonterminals(order="postorder"):
            node.yvalue = np.mean([x.yvalue for x in node.clades])
        self.dump_attr.extend(['yvalue', 'xvalue', 'clade'])
        if self.is_timetree:
            self.dump_attr.extend(['tvalue'])

    def export(self, path='', extra_attr=None, plain_export=10, indent=None,
               write_seqs_json=True):
        '''
        export the tree data structure along with the sequence information as
        json files for display in web browsers.
        parameters:
            path    -- path (incl prefix) to which the output files are written.
                       filenames themselves are standardized  to *tree.json and *sequences.json
            extra_attr -- attributes of tree nodes that are exported to json
                       (default: ['aa_muts', 'clade'])
            plain_export -- store sequences are plain strings instead of
                            differences to root if number of differences exceeds
                            len(seq)/plain_export
            indent  -- passed on to write_json
            write_seqs_json -- if False, the sequence json is not written
        '''
        if extra_attr is None:
            extra_attr = ['aa_muts', 'clade']
        timetree_fname = path + '_tree.json'
        sequence_fname = path + '_sequences.json'
        tree_json = tree_to_json(self.tree.root, extra_attr=extra_attr)
        write_json(tree_json, timetree_fname, indent=indent)

        # prepare a json with sequence information to export.
        # first step: add the sequence & translations of the root as string
        elems = {}
        elems['root'] = {}
        elems['root']['nuc'] = "".join(self.tree.root.sequence)
        for prot, seq in self.tree.root.translations.items():
            elems['root'][prot] = seq

        # add sequence for every node in tree. code as difference to root
        # or as full strings.
        for node in self.tree.find_clades():
            if hasattr(node, "clade"):
                elems[node.clade] = {}
                # loop over proteins and nucleotide sequences
                for prot, seq in [('nuc', "".join(node.sequence))] + list(node.translations.items()):
                    differences = {pos: state for pos, (state, ancstate)
                                   in enumerate(zip(seq, elems['root'][prot]))
                                   if state != ancstate}
                    if plain_export * len(differences) <= len(seq):
                        elems[node.clade][prot] = differences
                    else:
                        elems[node.clade][prot] = seq
        if write_seqs_json:
            write_json(elems, sequence_fname, indent=indent)
class tree(object):
    """Build a phylogenetic tree from an alignment and export it for web visualization.

    The instance keeps the alignment (``self.aln``), a lookup from sequence id
    to sequence (``self.sequence_lookup``), the protein features used for
    translation (``self.proteins``) and a list of node attribute names
    (``self.dump_attr``) that ``dump`` writes out.
    """

    def __init__(self, aln, proteins=None, **kwarks):
        """Store the alignment and configuration.

        Parameters:
            aln      -- iterable of sequences; each must have an ``id`` attribute
            proteins -- mapping of protein name to feature (used by
                        ``add_translations``); defaults to an empty dict
            kwarks   -- optional ``nuc`` (default True, passes ``-nt`` to
                        fasttree) and ``run_dir`` (scratch directory used by
                        ``build``; a time-stamped random name if omitted)
        """
        super(tree, self).__init__()
        self.aln = aln
        self.nthreads = 2
        self.sequence_lookup = {seq.id: seq for seq in aln}
        self.nuc = kwarks.get('nuc', True)
        self.dump_attr = []
        self.proteins = proteins if proteins is not None else {}
        if 'run_dir' in kwarks:
            self.run_dir = kwarks['run_dir']
        else:
            import random
            self.run_dir = '_'.join([
                'temp',
                time.strftime('%Y%m%d-%H%M%S', time.gmtime()),
                str(random.randint(0, 1000000))
            ])

    def dump(self, treefile, nodefile):
        """Write the tree as newick to `treefile` and pickle node attributes to `nodefile`.

        Only attributes listed in ``self.dump_attr`` that a node actually has
        are stored, keyed by node name.
        """
        from Bio import Phylo
        # cPickle only exists on Python 2; fall back to pickle elsewhere.
        try:
            import cPickle as pickle
        except ImportError:
            import pickle

        Phylo.write(self.tree, treefile, 'newick')
        node_props = {}
        for node in self.tree.find_clades():
            node_props[node.name] = {
                attr: getattr(node, attr)
                for attr in self.dump_attr if hasattr(node, attr)
            }
        # pickle streams are bytes, hence binary mode
        with myopen(nodefile, 'wb') as nfile:
            pickle.dump(node_props, nfile)

    def build(self, root='midpoint', raxml=True, raxml_time_limit=0.5,
              raxml_bin='raxml'):
        """Infer a tree with fasttree, optionally refine it with RAxML, and load it.

        Parameters:
            root             -- rooting argument handed on to ``tt_from_file``
            raxml            -- if True, run RAxML after fasttree
            raxml_time_limit -- hours allowed for the RAxML topology search;
                                if not positive the topology search is skipped
            raxml_bin        -- name/path of the RAxML executable

        All intermediate files are written inside ``self.run_dir``, which is
        removed again after the tree has been loaded.
        """
        from Bio import Phylo, AlignIO
        import subprocess, glob, shutil

        start_dir = os.getcwd()
        make_dir(self.run_dir)
        os.chdir(self.run_dir)
        try:
            for seq in self.aln:
                seq.name = seq.id
            AlignIO.write(self.aln, 'temp.fasta', 'fasta')

            tree_cmd = ["fasttree"]
            if self.nuc:
                tree_cmd.append("-nt")
            tree_cmd.extend(["temp.fasta", ">", "initial_tree.newick"])
            os.system(" ".join(tree_cmd))

            out_fname = "tree_infer.newick"
            if raxml:
                if raxml_time_limit > 0:
                    tmp_tree = Phylo.read('initial_tree.newick', 'newick')
                    resolve_iter = 0
                    resolve_polytomies(tmp_tree)
                    while (not tmp_tree.is_bifurcating()) and (resolve_iter < 10):
                        resolve_iter += 1
                        resolve_polytomies(tmp_tree)
                    Phylo.write(tmp_tree, 'initial_tree.newick', 'newick')
                    AlignIO.write(self.aln, "temp.phyx", "phylip-relaxed")
                    print("RAxML tree optimization with time limit",
                          raxml_time_limit, "hours")
                    # using exec to be able to kill process
                    end_time = time.time() + int(raxml_time_limit * 3600)
                    process = subprocess.Popen(
                        "exec " + raxml_bin + " -f d -T " + str(self.nthreads) +
                        " -j -s temp.phyx -n topology -c 25 -m GTRCAT -p 344312987 -t initial_tree.newick",
                        shell=True)
                    while time.time() < end_time:
                        if os.path.isfile('RAxML_result.topology'):
                            break
                        time.sleep(10)
                    process.terminate()
                    # reap the child so its output files are final and no
                    # zombie process is left behind
                    process.wait()

                    # glob returns files in arbitrary order; sort by
                    # modification time so the last entry is the newest
                    checkpoint_files = sorted(glob.glob("RAxML_checkpoint*"),
                                              key=os.path.getmtime)
                    if os.path.isfile('RAxML_result.topology'):
                        checkpoint_files.append('RAxML_result.topology')
                    if checkpoint_files:
                        shutil.copy(checkpoint_files[-1], 'raxml_tree.newick')
                    else:
                        shutil.copy("initial_tree.newick", 'raxml_tree.newick')
                else:
                    # NOTE(review): temp.phyx is only written in the branch
                    # above, so the branch length optimization below has no
                    # alignment file in this case -- confirm this is intended.
                    shutil.copy("initial_tree.newick", 'raxml_tree.newick')

                try:
                    print("RAxML branch length optimization")
                    os.system(
                        raxml_bin + " -f e -T " + str(self.nthreads) +
                        " -s temp.phyx -n branches -c 25 -m GTRGAMMA -p 344312987 -t raxml_tree.newick"
                    )
                    shutil.copy('RAxML_result.branches', out_fname)
                except (IOError, OSError):
                    # RAxML produced no result file: keep the topology tree
                    print("RAxML branch length optimization failed")
                    shutil.copy('raxml_tree.newick', out_fname)
            else:
                shutil.copy('initial_tree.newick', out_fname)
            self.tt_from_file(out_fname, root)
        finally:
            # always return to where we started, also for nested run_dir
            # paths and when a step above raised
            os.chdir(start_dir)
        remove_dir(self.run_dir)
        self.is_timetree = False

    def tt_from_file(self, infile, root='best', nodefile=None):
        """Create ``self.tt`` (TreeTime) from a tree file and attach sequence attributes.

        Parameters:
            infile   -- tree file handed to TreeTime
            root     -- argument passed to ``TreeTime.reroot``
            nodefile -- optional pickle written by ``dump``; its per-node
                        attributes are set on the nodes with matching names
        """
        from treetime import TreeTime
        print('Reading tree from file', infile)
        # NOTE(review): the filter tests for 'date' but 'num_date' is read;
        # a sequence with 'date' and no 'num_date' raises KeyError.
        dates = {
            seq.id: seq.attributes['num_date']
            for seq in self.aln if 'date' in seq.attributes
        }
        self.tt = TreeTime(dates=dates, tree=infile, gtr='Jukes-Cantor',
                           aln=self.aln, verbose=4)
        self.tt.reroot(root=root)
        self.tree = self.tt.tree

        for node in self.tree.find_clades():
            if node.is_terminal() and node.name in self.sequence_lookup:
                seq = self.sequence_lookup[node.name]
                node.attr = seq.attributes
                try:
                    node.attr['date'] = node.attr['date'].strftime('%Y-%m-%d')
                except (KeyError, AttributeError, ValueError):
                    # best effort: no date, not a date object, or a date
                    # that strftime cannot format -- leave it untouched
                    pass
            else:
                node.attr = {}

        if nodefile is not None:
            print('reading node properties from file:', nodefile)
            try:
                import cPickle as pickle
            except ImportError:
                import pickle
            # NOTE: unpickling executes arbitrary code; only load node files
            # from trusted sources (normally those written by dump()).
            with myopen(nodefile, 'rb') as nfile:
                node_props = pickle.load(nfile)
            for n in self.tree.find_clades():
                if n.name in node_props:
                    for attr in node_props[n.name]:
                        setattr(n, attr, node_props[n.name][attr])
                else:
                    print("No node properties found for ", n.name)

    def ancestral(self):
        """Run TreeTime's joint sequence/branch-length optimization (inferring a GTR)."""
        self.tt.optimize_seq_and_branch_len(infer_gtr=True)
        self.dump_attr.append('sequence')

    def timetree(self, Tc=0.01, infer_gtr=True, reroot='best',
                 resolve_polytomies=True, max_iter=2, **kwarks):
        """Run TreeTime and copy each node's ``numdate`` into ``node.attr['num_date']``.

        The named parameters are passed to ``TreeTime.run`` (``reroot`` as
        ``root``); additional keyword arguments are accepted and ignored.
        """
        print('estimating time tree...')
        self.tt.run(infer_gtr=infer_gtr, root=reroot, Tc=Tc,
                    resolve_polytomies=resolve_polytomies, max_iter=max_iter)
        self.dump_attr.extend(['numdate', 'date', 'sequence'])
        for node in self.tree.find_clades():
            if hasattr(node, 'attr'):
                node.attr['num_date'] = node.numdate
            else:
                node.attr = {'num_date': node.numdate}
        self.is_timetree = True

    def geo_inference(self, attr):
        '''
        infer a "mugration" model by pretending each region corresponds to a sequence
        state and repurposing the GTR inference and ancestral reconstruction

        parameters:
            attr -- key in node.attr holding the discrete trait (e.g. a region)

        Prints a message and returns without modifying any node if fewer
        than 2 or more than 180 distinct values of `attr` are found.
        '''
        # Determine the alphabet first, without modifying any node, so that
        # the early return below leaves the tree exactly as it was.
        places = set()
        for node in self.tree.find_clades():
            if hasattr(node, 'attr') and attr in node.attr:
                places.add(node.attr[attr])
        places = sorted(places)
        nc = len(places)
        if nc < 2 or nc > 180:
            print(
                "geo_inference: can't have less than 2 or more than 180 places!"
            )
            return

        from treetime.gtr import GTR

        # store reconstructed ancestral sequences and mutations; they are
        # overwritten by the trait inference and restored afterwards
        nuc_seqs = {}
        nuc_muts = {}
        for node in self.tree.find_clades():
            if hasattr(node, 'sequence'):
                nuc_seqs[node] = node.sequence
            if hasattr(node, 'mutations'):
                nuc_muts[node] = node.mutations
                delattr(node, 'mutations')

        # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45)
        alphabet = {chr(65 + i): place for i, place in enumerate(places)}
        alphabet_rev = {v: k for k, v in alphabet.items()}
        sequence_gtr = self.tt.gtr
        myGeoGTR = GTR.custom(pi=np.ones(nc, dtype=float) / nc,
                              W=np.ones((nc, nc)),
                              alphabet=np.array(sorted(alphabet)))
        myGeoGTR.profile_map['-'] = np.ones(nc)

        # set geo info to nodes as one letter sequence.
        for node in self.tree.get_terminals():
            if hasattr(node, 'attr'):
                if attr in node.attr:
                    node.sequence = np.array([alphabet_rev[node.attr[attr]]])
                else:
                    node.sequence = np.array(['-'])
        for node in self.tree.get_nonterminals():
            if hasattr(node, 'sequence'):
                delattr(node, 'sequence')

        # set custom GTR model, run inference
        self.tt._gtr = myGeoGTR
        tmp_use_mutation_length = self.tt.use_mutation_length
        self.tt.use_mutation_length = False
        self.tt.infer_ancestral_sequences(method='ml', infer_gtr=True,
                                          store_compressed=False, pc=5.0,
                                          marginal=True)

        # restore the nucleotide sequence and mutations to maintain expected behavior
        self.tt.geogtr = self.tt.gtr
        self.tt.geogtr.alphabet_to_location = alphabet
        self.tt._gtr = sequence_gtr
        self.dump_attr.append(attr)
        for node in self.tree.find_clades():
            node.attr[attr] = alphabet[node.sequence[0]]
            if node in nuc_seqs:
                node.sequence = nuc_seqs[node]
            if node.up is not None:
                setattr(node, attr + '_transitions', node.mutations)
                if node in nuc_muts:
                    node.mutations = nuc_muts[node]
        self.tt.use_mutation_length = tmp_use_mutation_length

    def _translate(self, node, prot):
        """Return the translation of protein `prot` for `node`, reading gaps ('-') as 'N'."""
        from Bio import Seq
        nuc = Seq.Seq("".join(node.sequence))
        return Seq.translate(
            str(self.proteins[prot].extract(nuc)).replace('-', 'N'))

    def add_translations(self):
        '''
        translate the nucleotide sequence into the proteins specified
        in self.proteins. these are expected to be SeqFeatures

        Sets node.translations[prot] on every node and node.aa_mutations[prot]
        as a list of (parent_state, position, state) tuples; the list is
        empty for the root.
        '''
        for node in self.tree.find_clades(order='preorder'):
            if not hasattr(node, "translations"):
                node.translations = {}
                node.aa_mutations = {}
            for prot in self.proteins:
                node.translations[prot] = self._translate(node, prot)
                if node.up is None:
                    node.aa_mutations[prot] = []
                else:
                    # preorder traversal guarantees the parent is translated
                    node.aa_mutations[prot] = [
                        (a, pos, d) for pos, (a, d) in enumerate(
                            zip(node.up.translations[prot],
                                node.translations[prot])) if a != d
                    ]
        self.dump_attr.append('translations')

    def refine(self):
        '''
        add attributes for export, currently this is only muts and aa_muts

        muts/aa_muts are strings of the form <from><1-based position><to>;
        node.attr["div"] accumulates mutation_length from the root (0).
        '''
        self.tree.ladderize()
        for node in self.tree.find_clades():
            if node.up is not None:
                node.muts = [
                    "".join(map(str, [a, pos + 1, d]))
                    for a, pos, d in node.mutations
                ]
            # set on every node (the root simply gets empty lists) so that
            # 'aa_muts' is uniformly available for export
            node.aa_muts = {}
            if hasattr(node, 'translations'):
                for prot in node.translations:
                    node.aa_muts[prot] = [
                        "".join(map(str, [a, pos + 1, d]))
                        for a, pos, d in node.aa_mutations[prot]
                    ]

        for node in self.tree.find_clades(order="preorder"):
            if node.up is not None:
                node.attr["div"] = node.up.attr["div"] + node.mutation_length
            else:
                node.attr["div"] = 0
        self.dump_attr.extend([
            'muts', 'aa_muts', 'aa_mutations', 'mutation_length', 'mutations'
        ])

    def layout(self):
        """Add clade, xvalue, yvalue and tvalue attributes to all nodes in tree.

        clade is the preorder index, xvalue the cumulative mutation_length
        from the root, tvalue the numdate difference to the root (0 unless
        self.is_timetree). Terminals get descending integer yvalues, internal
        nodes the mean yvalue of their children.
        """
        clade = 0
        yvalue = self.tree.count_terminals()
        for node in self.tree.find_clades(order="preorder"):
            node.clade = clade
            clade += 1
            if node.up is not None:
                node.xvalue = node.up.xvalue + node.mutation_length
                if self.is_timetree:
                    node.tvalue = node.numdate - self.tree.root.numdate
                else:
                    node.tvalue = 0
            else:
                node.xvalue = 0
                node.tvalue = 0
            if node.is_terminal():
                node.yvalue = yvalue
                yvalue -= 1
        # postorder: children are placed before their parent
        for node in self.tree.get_nonterminals(order="postorder"):
            node.yvalue = np.mean([x.yvalue for x in node.clades])
        self.dump_attr.extend(['yvalue', 'xvalue', 'clade'])
        if self.is_timetree:
            self.dump_attr.extend(['tvalue'])

    def export(self, path='', extra_attr=['aa_muts', 'clade'],
               plain_export=10):
        '''
        export the tree data structure along with the sequence information as
        json files for display in web browsers.
        parameters:
            path    -- path (incl prefix) to which the output files are written.
                       filenames themselves are standardized to *tree.json and *sequences.json
            extra_attr -- attributes of tree nodes that are exported to json
            plain_export -- store sequences as plain strings instead of
                            differences to root if number of differences exceeds
                            len(seq)/plain_export
        '''
        timetree_fname = path + 'tree.json'
        sequence_fname = path + 'sequences.json'
        # pass a copy so the shared default list can never be modified
        tree_json = tree_to_json(self.tree.root, extra_attr=list(extra_attr))
        write_json(tree_json, timetree_fname, indent=None)

        # prepare a json with sequence information to export.
        # first step: add the sequence & translations of the root as string
        elems = {}
        elems['root'] = {}
        elems['root']['nuc'] = "".join(self.tree.root.sequence)
        for prot, seq in self.tree.root.translations.items():
            elems['root'][prot] = seq

        # add sequence for every node in tree. code as difference to root
        # or as full strings.
        for node in self.tree.find_clades():
            if hasattr(node, "clade"):
                elems[node.clade] = {}
                # loop over proteins and nucleotide sequences
                for prot, seq in [('nuc', "".join(node.sequence))
                                  ] + list(node.translations.items()):
                    differences = {
                        pos: state
                        for pos, (state, ancstate) in enumerate(
                            zip(seq, elems['root'][prot]))
                        if state != ancstate
                    }
                    if plain_export * len(differences) <= len(seq):
                        elems[node.clade][prot] = differences
                    else:
                        elems[node.clade][prot] = seq

        write_json(elems, sequence_fname, indent=None)