def _build(parse_node, tree_node=None):
    """Convert a parse tree into a ``TreeNode`` hierarchy.

    Leaf parse nodes become tree nodes named after their literal and carry a
    one-element ``tokens`` feature.  Internal parse nodes are named after
    their symbol, carry their ``rule`` and gather the ``tokens`` of all of
    their children, in child order.

    Args:
        parse_node: a ``LeafNode`` or an ``InternalNode``.  Any other value
            leaves the target node unmodified (lists are echoed to stdout).
        tree_node: node to populate; a fresh ``TreeNode`` is created when
            omitted.

    Returns:
        The populated tree node.
    """
    result = TreeNode() if tree_node is None else tree_node

    # Diagnostic output for an unexpected list input.
    if isinstance(parse_node, list):
        print(parse_node)

    if isinstance(parse_node, LeafNode):
        label, leaf_token = parse_node.literal, parse_node.token
        result.name = label
        result.add_feature("tokens", [leaf_token])
        return result

    if isinstance(parse_node, InternalNode):
        label = parse_node.symbol
        production = parse_node.rule
        collected = []
        for kid in parse_node.children:
            subtree = _build(kid)
            result.add_child(subtree)
            collected.extend(subtree.tokens)
        result.name = label
        result.add_feature("rule", production)
        result.add_feature("tokens", collected)

    return result
def addDeadLineage(spTree):
    """
    Takes:
        - spTree (ete3.Tree) : species tree (left untouched, a deep copy is used)

    Returns:
        (ete3.Tree) : new root whose children are the copied species tree and
                      a dead lineage (name "-1") acting as outgroup.
                      Every node has a "dead" feature: True for the new root
                      and the dead lineage, False for all copied nodes.
    """
    living = deepcopy(spTree)
    living.dist = 0.1
    for lineage in living.traverse():
        lineage.add_feature("dead", False)

    root = TreeNode()
    root.dist = 0.0
    root.add_feature("dead", True)
    root.add_child(living)

    # The dead lineage gets a branch as long as the distance from the new
    # root to the first leaf of the copied tree.
    first_leaf = root.get_leaves()[0]
    outgroup = TreeNode()
    outgroup.name = "-1"
    outgroup.add_feature("dead", True)
    outgroup.dist = root.get_distance(first_leaf)
    root.add_child(outgroup)

    return root
def p_toArbre(self):
    """Build the display tree of a program node.

    The returned root is labelled "main()" and has three children, in order:
    a leaf named after ``str(self.sons[0])``, the command tree of
    ``self.sons[1]`` and the expression tree of ``self.sons[2]``.
    """
    root = TreeNode()
    root.name = "main()"

    head = TreeNode()
    head.name = str(self.sons[0])

    subtrees = (head, self.sons[1].c_toArbre(), self.sons[2].e_toArbre())
    for subtree in subtrees:
        root.add_child(subtree)
    return root
def e_toArbre(self):
    """Build the display tree of an expression node.

    * "NUMBER" -> leaf named "Number : <value>"
    * "ID"     -> leaf named "Id : <value>"
    * "OPBIN"  -> node named after the operator, with the trees of the two
                  operands (``self.sons[0]`` and ``self.sons[1]``) as children

    Any other type yields ``None``.
    """
    kind = self.type
    if kind == "NUMBER":
        leaf = TreeNode()
        leaf.name = "Number : " + str(self.value)
        return leaf

    if kind == "ID":
        leaf = TreeNode()
        leaf.name = "Id : " + self.value
        return leaf

    if kind == "OPBIN":
        operator = TreeNode()
        operator.name = self.value
        left = self.sons[0].e_toArbre()
        right = self.sons[1].e_toArbre()
        operator.add_child(left)
        operator.add_child(right)
        return operator

    return None
def c_toArbre(self):
    """Build the display tree of a command node.

    The node is named after ``self.value`` and always gets two children:

    * "="   -> a leaf "Id : <sons[0]>" and the expression tree of ``sons[1]``
    * ";"   -> the command trees of ``sons[0]`` and ``sons[1]``
    * other -> the expression tree of ``sons[0]`` and the command tree of
               ``sons[1]``
    """
    keyword = self.value
    node = TreeNode()
    node.name = keyword

    if keyword == "=":
        first = TreeNode()
        first.name = "Id : " + self.sons[0]
        second = self.sons[1].e_toArbre()
    elif keyword == ';':
        first = self.sons[0].c_toArbre()
        second = self.sons[1].c_toArbre()
    else:
        first = self.sons[0].e_toArbre()
        second = self.sons[1].c_toArbre()

    node.add_child(first)
    node.add_child(second)
    return node
def add_tree_to_distribution(self, tree):
    """
    Add the bipartitions of a tree to the CCP distribution.

    Takes:
        - tree (ete3.Tree): phylogenetic tree. It is modified in place when
                            its root is a trifurcation (see below).

    Side effects:
        - on the first observed tree, every leaf name is registered through
          self.get_leaf_id
        - every node is passed to self.add_tree_branch_to_distribution
          (postorder)
        - self.nb_observation is incremented

    Exits the program (SystemExit with code 1) if a node has more than two
    children once the root trifurcation has been resolved.
    """
    if len(tree.children) == 3:
        ## special unrooted case where the tree begins with a trifurcation:
        ## the 2nd and 3rd children are regrouped under a new internal node
        ## so that the root becomes a bifurcation
        a = TreeNode()
        b = tree.children[1]
        c = tree.children[2]
        b.detach()
        c.detach()
        tree.add_child(a)
        a.add_child(b)
        a.add_child(c)

    if any(len(i.children) > 2 for i in tree.traverse()):
        # print() with one argument gives the same output on Python 2 and 3
        print("multifurcation detected! Please provide bifurcating trees.")
        print("exiting now")
        # same exit status as exit(1), without relying on the builtin that
        # the site module injects
        raise SystemExit(1)

    if self.nb_observation == 0:
        ## no tree has been observed yet: add all the leaves to the CCP
        for l in tree.get_leaf_names():
            self.get_leaf_id(l)

    for node in tree.traverse("postorder"):  ## for each branch of the tree
        self.add_tree_branch_to_distribution(node)

    self.nb_observation += 1
    return
def subdivideSpTree(spTree):
    """
    Takes:
        - spTree (ete3.Tree) : an ULTRAMETRIC species tree (left untouched,
                               a deep copy is subdivided)

    Returns:
        (ete3.Tree) : subdivided species tree where all nodes have a timeSlice feature
        or
        None if the species tree is not ultrametric
    """
    newSpTree = deepcopy(spTree)
    featureName = "timeSlice"

    ## 1/ getting distance from root.
    Dheight = getDistFromRootDic(newSpTree, checkUltrametric=True)
    if Dheight is None:
        # print() with one argument gives the same output on Python 2 and 3
        print("!!ERROR!! : the species tree is not ultrametric")
        return None

    # we know that there is n-1 internal nodes (where n is the number of leaves)
    # hence the maximal timeSlice is n-1 (all leaves have timeSlice 0)

    ## 2/ assign timeSlice to nodes, from the node closest to the root down.
    currentTS = len(newSpTree.get_leaves()) - 1
    # Sort on the height only.  The former (height, node) key relied on
    # Python 2's arbitrary ordering of objects; on Python 3, comparing two
    # nodes of equal height (e.g. any two leaves) is expected to raise
    # TypeError.  Ties keep the dictionary's iteration order (stable sort).
    for n, _height in sorted(Dheight.items(), key=lambda item: item[1]):
        n.add_feature(featureName, currentTS)
        if currentTS != 0:
            currentTS -= 1

    ## 3/ subdivide according to timeSlice: every branch spanning more than
    ##    one time slice receives one intermediate node per skipped slice.
    RealNodes = list(newSpTree.traverse())  # snapshot: the tree is edited below
    for n in RealNodes:
        if n.is_root():
            continue

        nodeToAdd = n.up.timeSlice - n.timeSlice - 1
        while nodeToAdd > 0:
            # insert a new node directly above n; its slice is one below
            # the slice of n's current parent
            parentNode = n.up
            n.detach()

            NullNode = TreeNode()
            NullNode.add_feature(featureName, parentNode.timeSlice - 1)
            if "dead" in n.features:
                NullNode.add_feature("dead", n.dead)

            parentNode.add_child(NullNode)
            NullNode.add_child(n)

            nodeToAdd -= 1

    return newSpTree
# Fetch the tree node registered for `parent`, creating and registering it
# on first sight.
if parent not in nodes:
    parentNode = TreeNode(name=parent)
    parentNode.set_style(nstyle)
    nodes[parent] = parentNode
else:
    parentNode = nodes[parent]

# Same for `node`; a newly created child is attached below parentNode.
if node not in nodes:
    childNode = parentNode.add_child(name=node)
    childNode.set_style(nstyle)
    nodes[node] = childNode
else:
    childNode = nodes[node]

print(G)

# Nodes with at least four ancestors get a short branch, all others a
# unit-length branch.
for n in G.traverse():
    n.dist = 0.1 if len(n.get_ancestors()) >= 4 else 1.0

ts = TreeStyle()
class CollapsedTree(LeavesAndClades):
    '''
    Here's a derived class for a collapsed tree, where we recurse into the mutant clades
          (4)
         / | \\
       (3)(1)(2)
           | \\
          (2) (1)
    '''

    def __init__(self,
                 params=None,
                 tree=None,
                 frame=None,
                 collapse_syn=False,
                 allow_repeats=False):
        '''
        For intialization, either params or tree (or both) must be provided
        params: offspring distribution parameters
        tree: ete tree with frequency node feature. If uncollapsed, it will be collapsed
        frame: tranlation frame, with default None, no tranlation attempted
        collapse_syn: if True, branch lengths of the input tree are replaced (in place)
                      by amino acid hamming distances before collapsing
        allow_repeats: if True, repeated observed sequences in the collapsed tree are
                       reported on stdout instead of raising RuntimeError
        '''
        LeavesAndClades.__init__(self, params=params)
        if frame is not None and frame not in (1, 2, 3):
            raise RuntimeError('frame must be 1, 2, 3, or None')
        self.frame = frame

        if collapse_syn is True:
            # NOTE(review): this branch needs both tree and frame to be set;
            # nothing checks that here - confirm callers always provide them
            tree.dist = 0  # no branch above root
            for node in tree.iter_descendants():
                # largest whole-codon window in the requested frame; the
                # child's length is used for the parent slice as well
                start = frame - 1
                stop = start + 3 * ((len(node.sequence) - start) // 3)
                aa = Seq(node.sequence[start:stop], generic_dna).translate()
                aa_parent = Seq(node.up.sequence[start:stop],
                                generic_dna).translate()
                node.dist = hamming_distance(aa, aa_parent)

        if tree is not None:
            self.tree = tree.copy()
            # remove unobserved internal unifurcations
            for node in self.tree.iter_descendants():
                parent = node.up
                if node.frequency == 0 and len(node.children) == 1:
                    node.delete(prevent_nondicotomic=False)
                    node.children[0].dist = hamming_distance(
                        node.children[0].sequence, parent.sequence)

            # iterate over the tree below root and collapse edges of zero length
            # if the node is a leaf and it's parent has nonzero frequency we combine taxa names to a set
            # this acommodates bootstrap samples that result in repeated genotypes
            observed_genotypes = set(leaf.name for leaf in self.tree)
            observed_genotypes.add(self.tree.name)
            for node in self.tree.get_descendants(strategy='postorder'):
                if node.dist == 0:
                    node.up.frequency += node.frequency
                    node_set = set([node.name]) if isinstance(
                        node.name, str) else set(node.name)
                    node_up_set = set([node.up.name]) if isinstance(
                        node.up.name, str) else set(node.up.name)
                    if node_up_set < observed_genotypes:
                        if node_set < observed_genotypes:
                            node.up.name = tuple(node_set | node_up_set)
                            if len(node.up.name) == 1:
                                node.up.name = node.up.name[0]
                    elif node_set < observed_genotypes:
                        node.up.name = tuple(node_set)
                        if len(node.up.name) == 1:
                            node.up.name = node.up.name[0]
                    node.delete(prevent_nondicotomic=False)

            final_observed_genotypes = self._observed_taxa(self.tree)
            if final_observed_genotypes != observed_genotypes:
                raise RuntimeError(
                    'observed genotypes don\'t match after collapse\n\tbefore: {}\n\tafter: {}\n\tsymmetric diff: {}'
                    .format(observed_genotypes, final_observed_genotypes,
                            observed_genotypes ^ final_observed_genotypes))
            assert sum(node.frequency for node in tree.traverse()) == sum(
                node.frequency for node in self.tree.traverse())

            # number of observed nodes minus number of distinct observed sequences
            rep_seq = sum(
                node.frequency > 0 for node in self.tree.traverse()) - len(
                    set([
                        node.sequence for node in self.tree.traverse()
                        if node.frequency > 0
                    ]))
            if rep_seq:
                message = (
                    'Repeated observed sequences in collapsed tree. {} sequences were found repeated.'
                    .format(rep_seq))
                if not allow_repeats:
                    raise RuntimeError(message)
                print(message)

            # a custom ladderize accounting for abundance and sequence to break ties in abundance
            for node in self.tree.traverse(strategy='postorder'):
                # add a partition feature and compute it recursively up the tree
                node.add_feature(
                    'partition', node.frequency +
                    sum(node2.partition for node2 in node.children))
                # sort children of this node based on partion and sequence
                node.children.sort(
                    key=lambda node: (node.partition, node.sequence))
        else:
            self.tree = tree

    def _observed_taxa(self, subtree):
        '''
        set of taxon names found in subtree, taken from nodes that have nonzero
        frequency or that are the root of self.tree
        a node name is either a str or an iterable of names
        '''
        taxa = set()
        for node in subtree.traverse():
            if node.frequency > 0 or node == self.tree:
                if isinstance(node.name, str):
                    taxa.add(node.name)
                else:
                    taxa.update(node.name)
        return taxa

    def l(self, params, sign=1):
        '''
        log likelihood of params, conditioned on collapsed tree, and its gradient wrt params
        optional parameter sign must be 1 or -1, with the latter useful for MLE by minimization
        returns the tuple (sign * log likelihood, sign * gradient)
        raises ValueError if there is no tree or sign is invalid
        '''
        if self.tree is None:
            raise ValueError('tree data must be defined to compute likelihood')
        if sign not in (-1, 1):
            raise ValueError('sign must be 1 or -1')
        leaves_and_clades_list = [
            LeavesAndClades(c=node.frequency, m=len(node.children))
            for node in self.tree.traverse()
        ]
        root = leaves_and_clades_list[0]
        if root.c == 0 and root.m == 1 and root.f(params)[0] == 0:
            # if unifurcation not possible under current model, add a psuedocount for the naive
            root.c = 1
        # extract vector of function values and gradient components
        f_data = [
            leaves_and_clades.f(params)
            for leaves_and_clades in leaves_and_clades_list
        ]
        # NOTE(review): scipy.array / scipy.log / scipy.sqrt / scipy.zeros are
        # NumPy aliases, believed deprecated in recent SciPy - confirm against
        # the SciPy version in use before upgrading
        fs = scipy.array([[x[0]] for x in f_data])
        logf = scipy.log(fs).sum()
        grad_fs = scipy.array([x[1] for x in f_data])
        grad_logf = (grad_fs / fs).sum(axis=0)
        return sign * logf, sign * grad_logf

    def mle(self, **kwargs):
        '''
        Maximum likelihood estimate for params given tree
        updates params if not None
        returns optimization result
        '''
        # random initalization
        x_0 = (random.random(), random.random())
        bounds = ((.01, .99), (.001, .999))
        kwargs['sign'] = -1
        grad_check = check_grad(lambda x: self.l(x, **kwargs)[0],
                                lambda x: self.l(x, **kwargs)[1], (.4, .5))
        if grad_check > 1e-3:
            warnings.warn(
                'gradient mismatches finite difference approximation by {}'.
                format(grad_check), RuntimeWarning)
        result = minimize(lambda x: self.l(x, **kwargs),
                          x0=x_0,
                          jac=True,
                          method='L-BFGS-B',
                          options={'ftol': 1e-10},
                          bounds=bounds)
        # update params if None and optimization successful
        if not result.success:
            warnings.warn('optimization not sucessful, ' + result.message,
                          RuntimeWarning)
        elif self.params is None:
            self.params = result.x.tolist()
        return result

    def simulate(self):
        '''
        simulate a collapsed tree given params
        replaces existing tree data member with simulation result, and returns self
        raises ValueError if params is None
        '''
        if self.params is None:
            raise ValueError('params must be defined for simulation')

        # initiate by running a LeavesAndClades simulation to get the number of clones and mutants
        # in the root node of the collapsed tree
        LeavesAndClades.simulate(self)
        self.tree = TreeNode()
        self.tree.add_feature('frequency', self.c)
        if self.m == 0:
            return self
        for _ in range(self.m):
            # ooooh, recursion
            child = CollapsedTree(params=self.params,
                                  frame=self.frame).simulate().tree
            child.dist = 1
            self.tree.add_child(child)

        return self

    def __str__(self):
        '''return a string representation for printing'''
        return 'params = ' + str(self.params) + '\ntree:\n' + str(self.tree)

    def render(self,
               outfile,
               idlabel=False,
               colormap=None,
               show_support=False,
               chain_split=None):
        '''
        render to image file, filetype inferred from suffix, svg for color images
        idlabel: also label nodes with their name and write the sequences to a
                 fasta file next to outfile (same path, '.fasta' suffix)
        colormap: optional mapping from node name to a color name, or to a
                  mapping color -> count drawn as a pie chart
        show_support: passed to the tree style as show_branch_support
        chain_split: optional position splitting the sequence in two parts whose
                     mutations are colored separately (not supported with frame)
        '''
        def my_layout(node):
            circle_color = 'lightgray' if colormap is None or node.name not in colormap else colormap[
                node.name]
            text_color = 'black'
            if isinstance(circle_color, str):
                C = CircleFace(radius=max(3, 10 * scipy.sqrt(node.frequency)),
                               color=circle_color,
                               label={
                                   'text': str(node.frequency),
                                   'color': text_color
                               } if node.frequency > 0 else None)
                C.rotation = -90
                C.hz_align = 1
                faces.add_face_to_node(C, node, 0)
            else:
                P = PieChartFace(
                    [100 * x / node.frequency for x in circle_color.values()],
                    2 * 10 * scipy.sqrt(node.frequency),
                    2 * 10 * scipy.sqrt(node.frequency),
                    colors=[(color if color != 'None' else 'lightgray')
                            for color in list(circle_color.keys())],
                    line_color=None)
                T = TextFace(' '.join(
                    [str(x) for x in list(circle_color.values())]),
                             tight_text=True)
                T.hz_align = 1
                T.rotation = -90
                faces.add_face_to_node(P, node, 0, position='branch-right')
                faces.add_face_to_node(T, node, 1, position='branch-right')
            if idlabel:
                T = TextFace(node.name, tight_text=True, fsize=6)
                T.rotation = -90
                T.hz_align = 1
                faces.add_face_to_node(
                    T,
                    node,
                    1 if isinstance(circle_color, str) else 2,
                    position='branch-right')

        for node in self.tree.traverse():
            nstyle = NodeStyle()
            nstyle['size'] = 0
            if node.up is not None:
                if set(node.sequence.upper()) == set('ACGT'):
                    if chain_split is not None:
                        if self.frame is not None:
                            raise NotImplementedError(
                                'frame not implemented with chain_split')
                        leftseq_mutated = hamming_distance(
                            node.sequence[:chain_split],
                            node.up.sequence[:chain_split]) > 0
                        rightseq_mutated = hamming_distance(
                            node.sequence[chain_split:],
                            node.up.sequence[chain_split:]) > 0
                        if leftseq_mutated and rightseq_mutated:
                            nstyle['hz_line_color'] = 'purple'
                            nstyle['hz_line_width'] = 3
                        elif leftseq_mutated:
                            nstyle['hz_line_color'] = 'red'
                            nstyle['hz_line_width'] = 2
                        elif rightseq_mutated:
                            nstyle['hz_line_color'] = 'blue'
                            nstyle['hz_line_width'] = 2
                    if self.frame is not None:
                        # largest whole-codon window in the requested frame;
                        # the child's length is used for the parent slice too
                        start = self.frame - 1
                        stop = start + 3 * ((len(node.sequence) - start) // 3)
                        aa = Seq(node.sequence[start:stop],
                                 generic_dna).translate()
                        aa_parent = Seq(node.up.sequence[start:stop],
                                        generic_dna).translate()
                        nonsyn = hamming_distance(aa, aa_parent)
                        if '*' in aa:
                            nstyle['bgcolor'] = 'red'
                        if nonsyn > 0:
                            nstyle['hz_line_color'] = 'black'
                            nstyle['hz_line_width'] = nonsyn
                        else:
                            nstyle['hz_line_type'] = 1
            node.set_style(nstyle)

        ts = TreeStyle()
        ts.show_leaf_name = False
        ts.rotation = 90
        ts.draw_aligned_faces_as_table = False
        ts.allow_face_overlap = True
        ts.layout_fn = my_layout
        ts.show_scale = False
        ts.show_branch_support = show_support
        self.tree.render(outfile, tree_style=ts)
        # if we labelled seqs, let's also write the alignment out so we have the sequences (including of internal nodes)
        if idlabel:
            aln = MultipleSeqAlignment([])
            for node in self.tree.traverse():
                aln.append(
                    SeqRecord(Seq(str(node.sequence), generic_dna),
                              id=str(node.name),
                              description='abundance={}'.format(
                                  node.frequency)))
            # context manager so the fasta file is flushed and closed
            with open(os.path.splitext(outfile)[0] + '.fasta',
                      'w') as fasta_handle:
                AlignIO.write(aln, fasta_handle, 'fasta')

    def write(self, file_name):
        '''serialize tree to file'''
        with open(file_name, 'wb') as f:
            pickle.dump(self, f)

    def compare(self, tree2, method='identity'):
        '''
        compare this tree to the other tree
        method: 'identity' (bool: same multiset of sequence/abundance/parent sequence),
                'MRCA' (summed hamming distance between MRCA sequences over summed sequence lengths),
                'RF' (Robinson-Foulds distance on the sequence feature)
        raises ValueError for any other method
        '''
        if method == 'identity':
            # we compare lists of seq, parent, abundance
            # return true if these lists are identical, else false
            list1 = sorted((node.sequence, node.frequency,
                            node.up.sequence if node.up is not None else None)
                           for node in self.tree.traverse())
            list2 = sorted((node.sequence, node.frequency,
                            node.up.sequence if node.up is not None else None)
                           for node in tree2.tree.traverse())
            return list1 == list2
        elif method == 'MRCA':
            # matrix of hamming distance of common ancestors of taxa
            # takes a true and inferred tree as CollapsedTree objects
            taxa = [
                node.sequence for node in self.tree.traverse()
                if node.frequency
            ]
            n_taxa = len(taxa)
            d = scipy.zeros(shape=(n_taxa, n_taxa))
            sum_sites = scipy.zeros(shape=(n_taxa, n_taxa))
            for i in range(n_taxa):
                # next(it) rather than it.next(): generators have no .next()
                # method on Python 3
                nodei_true = next(
                    self.tree.iter_search_nodes(sequence=taxa[i]))
                nodei = next(tree2.tree.iter_search_nodes(sequence=taxa[i]))
                for j in range(i + 1, n_taxa):
                    nodej_true = next(
                        self.tree.iter_search_nodes(sequence=taxa[j]))
                    nodej = next(
                        tree2.tree.iter_search_nodes(sequence=taxa[j]))
                    MRCA_true = self.tree.get_common_ancestor(
                        (nodei_true, nodej_true)).sequence
                    MRCA = tree2.tree.get_common_ancestor(
                        (nodei, nodej)).sequence
                    d[i, j] = hamming_distance(MRCA_true, MRCA)
                    sum_sites[i, j] = len(MRCA_true)
            return d.sum() / sum_sites.sum()
        elif method == 'RF':
            tree1_copy = self.tree.copy(method='deepcopy')
            tree2_copy = tree2.tree.copy(method='deepcopy')
            # hang a leaf carrying the sequence under every observed node, so
            # that observed internal nodes are represented by a leaf too
            for treex in (tree1_copy, tree2_copy):
                for node in list(treex.traverse()):
                    if node.frequency > 0:
                        child = TreeNode()
                        child.add_feature('sequence', node.sequence)
                        node.add_child(child)
            try:
                return tree1_copy.robinson_foulds(tree2_copy,
                                                  attr_t1='sequence',
                                                  attr_t2='sequence',
                                                  unrooted_trees=True)[0]
            except Exception:
                # retry tolerating duplicated attribute values; narrowed from
                # a bare except so KeyboardInterrupt/SystemExit propagate
                return tree1_copy.robinson_foulds(tree2_copy,
                                                  attr_t1='sequence',
                                                  attr_t2='sequence',
                                                  unrooted_trees=True,
                                                  allow_dup=True)[0]
        else:
            raise ValueError('invalid distance method: ' + method)

    def get_split(self, node):
        '''
        return the bipartition resulting from clipping this node's edge above
        the node is temporarily detached and then re-attached to its parent
        (as the last child)
        raises ValueError if node is not in this tree or is its root
        '''
        if node.get_tree_root() != self.tree:
            raise ValueError('node not found')
        if node == self.tree:
            raise ValueError('this node is the root (no split above)')
        parent = node.up
        taxa1 = self._observed_taxa(node)
        node.detach()
        try:
            taxa2 = self._observed_taxa(self.tree)
        finally:
            # always restore the tree, even if collecting taxa failed
            parent.add_child(node)
        assert taxa1.isdisjoint(taxa2)
        assert taxa1.union(taxa2) == self._observed_taxa(self.tree)
        return tuple(sorted([taxa1, taxa2]))

    @staticmethod
    def split_compatibility(split1, split2):
        '''
        True if some part of split1 is disjoint from some part of split2
        raises ValueError if the two splits do not cover the same taxa
        '''
        diff = split1[0].union(split1[1]) ^ split2[0].union(split2[1])
        if diff:
            raise ValueError(
                'splits do not cover the same taxa\n\ttaxa not in both: {}'.
                format(diff))
        for partition1 in split1:
            for partition2 in split2:
                if partition1.isdisjoint(partition2):
                    return True
        return False

    def support(self,
                bootstrap_trees_list,
                weights=None,
                compatibility=False):
        '''
        compute support from a list of bootstrap GCtrees
        weights (optional) is needed for weighting parsimony degenerate trees
        compatibility mode counts trees that don't disconfirm the split
        sets the support attribute of every node below the root and returns self
        '''
        for node in self.tree.get_descendants():
            split = self.get_split(node)
            support = 0
            compatibility_ = 0
            for i, tree in enumerate(bootstrap_trees_list):
                compatible = True
                supported = False
                for boot_node in tree.tree.get_descendants():
                    boot_split = tree.get_split(boot_node)
                    if compatibility and compatible and not self.split_compatibility(
                            split, boot_split):
                        compatible = False
                    if not compatibility and not supported and boot_split == split:
                        supported = True
                if supported:
                    support += weights[i] if weights is not None else 1
                if compatible:
                    compatibility_ += weights[i] if weights is not None else 1
            node.support = compatibility_ if compatibility else support

        return self
def simulate(
    self,
    sequence: str,
    seq_bounds: Tuple[Tuple[int, int], Tuple[int, int]] = None,
    fitness_function: Callable = lambda seq: 0.9,
    lambda0: List[np.float64] = [1],
    frame: int = None,
    N_init: int = 1,
    N: int = None,
    T: int = None,
    n: int = None,
    verbose: bool = False,
) -> "TreeNode":
    r"""Simulate a neutral binary branching process with the mutation model,
    returning a :class:`ete3.Treenode` object.

    Exactly one of ``N`` and ``T`` must be given.

    Args:
        sequence: root nucleotide sequence
        seq_bounds: ranges for two subsequences used as two parallel genes
        fitness_function: mean number offspring as a function of sequence
        lambda0: baseline mutation rate(s); the second entry is only used
            with ``seq_bounds``
        frame: coding frame of starting position(s)
        N_init: initial naive abundnace
        N: maximum population size
        T: maximum generation time, either a single number or a sequence of
            sampling times whose largest value is the final generation
        n: sample size
        verbose: print more messages

    Returns:
        The uncollapsed tree; observed nodes have ``abundance`` 1, lineages
        without any observed node are pruned and every node is named
        ``simcell_<i>``.

    Raises:
        ValueError: if both or neither of ``N`` and ``T`` are given, or if
            ``n`` is larger than ``N``.
        RuntimeError: if the process ends with fewer unterminated leaves than
            ``N``, or with fewer leaves than ``n`` where that is checked.
    """
    # Checking the validity of the input parameters:
    if N is not None and T is not None:
        raise ValueError(
            "Only one of N and T can be used. One must be None.")
    elif N is None and T is None:
        raise ValueError("Either N or T must be specified.")
    if N is not None and n is not None and n > N:
        raise ValueError("n ({}) must not larger than N ({})".format(n, N))

    # T is documented as a single generation time but is consumed below via
    # max(T), len(T) and sorted(T); wrap a scalar so both forms work.
    if T is not None and np.ndim(T) == 0:
        T = [T]

    # Planting the tree:
    tree = TreeNode()
    tree.dist = 0
    tree.add_feature("sequence", sequence)
    tree.add_feature("terminated", False)
    tree.add_feature("abundance", 0)
    tree.add_feature("time", 0)
    # add fitness attribute, interpreted as mean of offspring distribution
    tree.add_feature("fitness", fitness_function(tree.sequence))

    if N_init > 1:
        for _ in range(N_init):
            child = TreeNode()
            child.dist = 0
            child.add_feature("sequence", sequence)
            child.add_feature("abundance", 0)
            child.add_feature("terminated", False)
            child.add_feature("time", 0)
            # add fitness attribute, interpreted as mean of offspring distribution
            child.add_feature("fitness", fitness_function(child.sequence))
            tree.add_child(child)

    t = 0  # <-- time
    leaves_unterminated = N_init
    while (leaves_unterminated > 0
           and (leaves_unterminated < N if N is not None else True)
           and (t < max(T) if T is not None else True)):
        if verbose:
            print("At time:", t)
        t += 1
        list_of_leaves = list(tree.iter_leaves())
        random.shuffle(list_of_leaves)
        for leaf in list_of_leaves:
            # add fitness attribute, interpreted as mean of offspring distribution
            leaf.add_feature("fitness", fitness_function(leaf.sequence))
            if not leaf.terminated:
                n_children = poisson(leaf.fitness).rvs()
                # <-- this kills the parent if we drew a zero
                leaves_unterminated += n_children - 1
                if not n_children:
                    leaf.terminated = True
                for _ in range(n_children):
                    # If sequence pair mutate them separately with their own mutation rate:
                    if seq_bounds is not None:
                        mutated_sequence1 = self.mutate(
                            leaf.sequence[seq_bounds[0][0]:seq_bounds[0][1]],
                            lambda0=lambda0[0],
                            frame=frame,
                        )
                        mutated_sequence2 = self.mutate(
                            leaf.sequence[seq_bounds[1][0]:seq_bounds[1][1]],
                            lambda0=lambda0[1],
                            frame=frame,
                        )
                        mutated_sequence = mutated_sequence1 + mutated_sequence2
                    else:
                        mutated_sequence = self.mutate(leaf.sequence,
                                                       lambda0=lambda0[0],
                                                       frame=frame)
                    child = TreeNode()
                    child.dist = utils.hamming_distance(
                        mutated_sequence, leaf.sequence)
                    child.add_feature("sequence", mutated_sequence)
                    child.add_feature("abundance", 0)
                    child.add_feature("terminated", False)
                    child.add_feature("time", t)
                    leaf.add_child(child)

    if N is not None and leaves_unterminated < N:
        raise RuntimeError(
            "tree terminated with {} leaves, {} desired".format(
                leaves_unterminated, N))

    # each leaf in final generation gets an observed abundance of 1, unless downsampled
    if T is not None and len(T) > 1:
        # Iterate the intermediate time steps:
        for Ti in sorted(T)[:-1]:
            # Only sample those that have been 'sampled' at intermediate sampling times.
            # NOTE(review): nothing in this function sets a "sampled" feature,
            # so this lookup presumably fails unless it is set elsewhere - confirm.
            final_leaves = [
                leaf for leaf in tree.iter_descendants()
                if leaf.time == Ti and leaf.sampled
            ]
            if len(final_leaves) < n:
                raise RuntimeError(
                    "tree terminated with {} leaves, less than what desired after downsampling {}"
                    .format(leaves_unterminated, n))
            # No need to down-sample, this was already done in the simulation loop
            for leaf in final_leaves:
                leaf.abundance = 1

    # Do the normal sampling of the last time step:
    final_leaves = [leaf for leaf in tree.iter_leaves() if leaf.time == t]
    # by default, downsample to the target simulation size
    if n is not None and len(final_leaves) >= n:
        for leaf in random.sample(final_leaves, n):
            leaf.abundance = 1
    elif n is None and N is not None:
        for leaf in random.sample(final_leaves, N):
            leaf.abundance = 1
    elif N is None and T is not None:
        # NOTE(review): also reached when n is given but there are fewer than
        # n final leaves; all of them are then observed without an error.
        for leaf in final_leaves:
            leaf.abundance = 1
    elif n is not None and len(final_leaves) < n:
        raise RuntimeError(
            "tree terminated with {} leaves, less than what desired after downsampling {}"
            .format(leaves_unterminated, n))
    else:
        raise RuntimeError("Unknown option.")

    # prune away lineages that are unobserved
    for node in tree.iter_descendants():
        if sum(node2.abundance for node2 in node.traverse()) == 0:
            node.detach()

    # assign unique names to each node
    for i, node in enumerate(tree.traverse(), 1):
        node.name = "simcell_{}".format(i)

    # return the uncollapsed tree
    return tree