def make_oldctree(tree):
    """Build an ``OldCollapsedTree`` from an hDAG clade tree.

    The clade tree is converted with ``tree.to_ete``: node names and
    abundances are read from each clade node's ``attr`` mapping and the
    ``sequence`` feature is requested. Afterwards every non-leaf node has its
    abundance reset to 0, the root branch length is set to 0, and every other
    branch length becomes the Hamming distance between the parent and child
    sequences.
    """

    def _name_of(clade_node):
        return clade_node.attr["name"]

    def _abundance_of(clade_node):
        return clade_node.attr["abundance"]

    ete_root = tree.to_ete(
        name_func=_name_of,
        features=["sequence"],
        feature_funcs={"abundance": _abundance_of},
    )

    # only leaves keep the abundance read from the clade tree
    internal_nodes = [n for n in ete_root.traverse() if not n.is_leaf()]
    for internal in internal_nodes:
        internal.abundance = 0

    # branch lengths are mutation counts along each edge
    ete_root.dist = 0
    for child in ete_root.iter_descendants():
        child.dist = utils.hamming_distance(child.up.sequence, child.sequence)

    return OldCollapsedTree(ete_root)
def isotype_tree(
    tree: ete3.TreeNode,
    newidmap: Dict[str, Dict[str, str]],
    isotype_names: Sequence[str],
    weight_matrix: Optional[Sequence[Sequence[float]]] = None,
) -> ete3.TreeNode:
    """Return a copy of ``tree`` annotated with isotypes.

    Isotypes are assigned so that isotype switching is minimized and the
    switching order is obeyed. The work happens in three stages on a copy of
    the input tree:

    * Observed isotypes are attached to each observed node of the collapsed
      tree produced by gctree inference. When cells sharing a sequence were
      observed with different isotypes, the collapsed node has to be
      'exploded' into new nodes carrying the appropriate isotypes and
      abundances. Each unique sequence ID generated by gctree is prepended to
      its observed isotype, and a new ``isotyped.idmap`` mapping these new
      sequence IDs to the original sequence IDs is written in the output
      directory.
    * Isotypes of unobserved ancestral genotypes are resolved so that isotype
      switching is minimized and switching order is respected. If the
      observed isotypes of an observed internal node and its children violate
      switching order, the internal node is replaced by an unobserved node
      with the same sequence and the observed node becomes a child leaf. This
      always allows switching order conflicts to be resolved, and should
      usually increase the isotype transitions required in the resulting
      tree.
    * Nodes are collapsed by sequence and isotype.

    Finally each node name gets its isotype appended (separated by a space)
    and branch lengths are recomputed as parent/child Hamming distances.

    Args:
        tree: ete3 Tree
        newidmap: mapping of sequence IDs to isotypes, such as that output by
            :meth:`utils.explode_idmap`.
        isotype_names: list or other sequence of isotype names observed, in
            correct switching order.
        weight_matrix: optional matrix forwarded unchanged to
            ``_add_observed_isotypes``.

    Returns:
        A new ete3 Tree whose nodes have isotype annotations in the attribute
        ``isotype``. Node names in this tree also contain isotype names.
    """
    annotated = tree.copy()

    # annotation pipeline; every step works in place on the copy
    _add_observed_isotypes(
        annotated, newidmap, isotype_names, weight_matrix=weight_matrix
    )
    _disambiguate_isotype(annotated)
    _collapse_tree_by_sequence_and_isotype(annotated)

    for current in annotated.traverse():
        current.name = " ".join((str(current.name), str(current.isotype)))
        if current is annotated:
            # the root has no parent edge to measure
            continue
        current.dist = hamming_distance(current.up.sequence, current.sequence)

    return annotated
def align_lineages(seq, tree_t, tree_i, gap_penalty_pct=0, known_root=True, allow_double_gap=False):
    """Align the lineage of ``seq`` in a true tree against an inferred tree.

    Standard implementation of a Needleman-Wunsch algorithm as described here:
    http://telliott99.blogspot.com/2009/08/alignment-needleman-wunsch.html
    https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm
    And implemented here:
    https://github.com/alevchuk/pairwise-alignment-in-python/blob/master/alignment.py

    The two lineages are obtained with ``find_node_by_seq`` and
    ``reconstruct_lineage``. The match score of two lineage members is minus
    their Hamming distance, so a perfect alignment scores 0.

    Args:
        seq: sequence whose lineage is looked up in both trees.
        tree_t: the true tree.
        tree_i: the inferred tree.
        gap_penalty_pct: gap penalty relative to the sequence length of the
            sequences on the tree (percent of ``len(seq)``, truncated to int).
        known_root: when ``True`` the first row and column of the score
            table are set to ``-inf`` (apart from the origin), so an
            alignment cannot start with a gap.
        allow_double_gap: when ``False`` gaps are only allowed in the shorter
            lineage (and not at all if both have the same length).

    Returns:
        ``False`` if neither lineage is longer than two entries, otherwise
        the list ``[align_t, align_i, alignment_score, max_penalty]``. The
        two alignment lists are in traceback order (last lineage entry first)
        and use ``"-"`` for gaps.

    Raises:
        RuntimeError: if the traceback cannot find the move that produced a
            cell of the score table (previously an endless loop).
    """
    nt = find_node_by_seq(tree_t, seq)
    lt = reconstruct_lineage(tree_t, nt)
    ni = find_node_by_seq(tree_i, seq)
    li = reconstruct_lineage(tree_i, ni)
    # At least one lineage must be longer than just the root and the terminal node
    if len(lt) <= 2 and len(li) <= 2:
        return False

    # Gap penalty chosen not too large:
    gap_penalty = -1 * int((len(seq) / 100.0) * gap_penalty_pct)
    assert gap_penalty <= 0, "penalties must be negative"
    if gap_penalty == 0:
        # If gap penalty is zero only gaps in the shortest sequence will be allowed
        assert allow_double_gap is False, "a zero gap penalty requires allow_double_gap=False"

    kt = len(lt)
    ki = len(li)
    forbidden = -1 * float("inf")
    # gap_penalty_i is paid when a true-lineage entry is aligned to a gap
    # (gap in inferred); gap_penalty_j when an inferred entry is aligned to a
    # gap (gap in true). Disallow gaps in the longest list:
    if allow_double_gap is False and kt > ki:
        # If true is longer than inferred allow gap only in inferred:
        gap_penalty_i = gap_penalty
        gap_penalty_j = forbidden
    elif allow_double_gap is False and kt < ki:
        # If inferred is longer than true allow gap only in true:
        gap_penalty_i = forbidden
        gap_penalty_j = gap_penalty
    elif allow_double_gap is False and kt == ki:
        # If lists are equally long no gaps are allowed:
        gap_penalty_i = forbidden
        gap_penalty_j = forbidden
    else:
        gap_penalty_i = gap_penalty
        gap_penalty_j = gap_penalty

    # Generate the score matrix. Notice the score is defined by number of
    # mismatches:
    sc_mat = np.zeros((kt, ki), dtype=np.float64)
    for i in range(kt):
        for j in range(ki):
            sc_mat[i, j] = -1 * hamming_distance(lt[i], li[j])

    # Calculate the alignment scores:
    aln_sc = np.zeros((kt + 1, ki + 1), dtype=np.float64)
    for i in range(0, kt + 1):
        if known_root is True:
            aln_sc[i][0] = -1 * float("inf")
        else:
            aln_sc[i][0] = gap_penalty_i * i
    for j in range(0, ki + 1):
        if known_root is True:
            aln_sc[0][j] = -1 * float("inf")
        else:
            aln_sc[0][j] = gap_penalty_j * j
    aln_sc[0][0] = 0  # The top left is fixed to zero
    for i in range(1, kt + 1):
        for j in range(1, ki + 1):
            match = aln_sc[i - 1][j - 1] + sc_mat[i - 1, j - 1]
            gap_in_inferred = aln_sc[i - 1][j] + gap_penalty_i
            gap_in_true = aln_sc[i][j - 1] + gap_penalty_j
            aln_sc[i][j] = max(match, gap_in_inferred, gap_in_true)

    # Traceback to compute the alignment:
    align_t, align_i = [], []
    i, j = kt, ki
    alignment_score = aln_sc[i][j]
    while i > 0 and j > 0:
        sc_current = aln_sc[i][j]
        sc_diagonal = aln_sc[i - 1][j - 1]
        sc_up = aln_sc[i][j - 1]
        sc_left = aln_sc[i - 1][j]

        if sc_current == (sc_diagonal + sc_mat[i - 1, j - 1]):
            align_t.append(lt[i - 1])
            align_i.append(li[j - 1])
            i -= 1
            j -= 1
        elif sc_current == (sc_left + gap_penalty_i):
            align_t.append(lt[i - 1])
            align_i.append("-")
            i -= 1
        elif sc_current == (sc_up + gap_penalty_j):
            align_t.append("-")
            align_i.append(li[j - 1])
            j -= 1
        else:
            # Without this branch neither index would move and the loop
            # would never terminate (e.g. if a score is NaN).
            raise RuntimeError(
                "alignment traceback failed at cell ({}, {})".format(i, j))

    # If space left fill it with gaps:
    while i > 0:
        align_t.append(lt[i - 1])
        align_i.append("-")
        i -= 1
    while j > 0:
        align_t.append("-")
        align_i.append(li[j - 1])
        j -= 1

    # Worst possible score for this alignment layout: every aligned pair
    # mismatching at every position, every gap paying the gap penalty.
    max_penalty = 0
    for a, b in zip(align_t, align_i):
        if a == "-" or b == "-":
            max_penalty += gap_penalty
        else:
            max_penalty += -len(a)
    # Notice that the root and the terminal node is excluded from this comparison.
    # by adding their length to the max_penalty:
    if known_root is True:
        max_penalty += 2 * len(lt[0])
    else:  # Or in the case of an unknown root, just add the terminal node
        max_penalty += len(lt[0])

    return [align_t, align_i, alignment_score, max_penalty]
def _distance_frame(fname, root_id):
    """Return per-sequence distance statistics for one fasta file.

    The frame has one row per non-root sequence with its Hamming distance to
    the root sequence and to its nearest non-root neighbor. Returns ``None``
    when the file holds two sequences or fewer.
    """
    seqs = {seq.id: str(seq.seq) for seq in fasta_parse(fname, root_id)[0]}
    if len(seqs) <= 2:
        return None
    observed = [seqid for seqid in seqs if seqid != root_id]
    distance_from_root = [
        hamming_distance(seqs[seqid], seqs[root_id]) for seqid in observed
    ]
    nearest_neighbor = [
        min(
            hamming_distance(seqs[seqid], seqs[seqid2])
            for seqid2 in observed
            if seqid2 != seqid
        )
        for seqid in observed
    ]
    return pd.DataFrame({
        "distance to root sequence": distance_from_root,
        "nearest neighbor distance": nearest_neighbor,
    })


def _aggregate_distance_frames(fnames, root_id):
    """Concatenate the distance frames of several fasta files.

    Each frame is tagged with a 1-based ``"data set"`` column holding the
    position of its file in ``fnames``. Returns ``None`` if every file was
    skipped.
    """
    frames = []
    for i, fname in enumerate(fnames):
        print(fname)
        df = _distance_frame(fname, root_id)
        if df is None:
            continue
        df["data set"] = i + 1
        frames.append(df)
    if not frames:
        return None
    # DataFrame.append was removed in pandas 2.0; concat is the replacement
    return pd.concat(frames, ignore_index=True)


def _plot_cumulative_histograms(column, ylabel, aggdat, aggdat_exp, alpha):
    """Draw cumulative step histograms of ``column`` on the current axes.

    Simulated data sets are drawn in gray, experimental ones (if any) in
    black.
    """
    bins = range(
        max(
            aggdat[column].max(),
            aggdat_exp[column].max() if aggdat_exp is not None else 0,
        ) + 2)
    # NOTE(review): sns.distplot is deprecated in recent seaborn releases;
    # kept so the figure stays unchanged.
    for _, dataset_aggdat in aggdat.groupby("data set"):
        sns.distplot(
            dataset_aggdat[column],
            bins=bins,
            kde=False,
            color="gray",
            hist_kws={
                "histtype": "step",
                "cumulative": True,
                "alpha": alpha,
                "lw": 1
            },
        )
    if aggdat_exp is not None:
        for _, dataset_aggdat in aggdat_exp.groupby("data set"):
            sns.distplot(
                dataset_aggdat[column],
                bins=bins,
                kde=False,
                color="black",
                hist_kws={
                    "histtype": "step",
                    "cumulative": True,
                    "alpha": 0.5,
                    "lw": 3,
                },
            )
    plt.xlabel(column)
    plt.xlim([0, bins[-1]])
    plt.ylabel(ylabel)
    plt.tight_layout()


def main():
    """Plot summary statistics of pre-tree data.

    Reads simulated (and optionally experimental) fasta files, computes for
    every non-root sequence its Hamming distance to the root and to its
    nearest neighbor, and saves cumulative histograms of both to
    ``<outbase>.pdf``. Files with two sequences or fewer are skipped.

    Raises:
        ValueError: if no simulated input file has more than two sequences.
    """
    parser = argparse.ArgumentParser(
        description="summary statistics of pre-tree data")
    parser.add_argument("input",
                        type=str,
                        nargs="+",
                        help="simulated fasta files")
    parser.add_argument("--experimental",
                        type=str,
                        nargs="*",
                        help="experimental fasta files")
    parser.add_argument("--outbase", type=str, help="output file base name")
    parser.add_argument("--root_idexp",
                        type=str,
                        default="root0",
                        help="root sequence ID")
    args = parser.parse_args()

    # simulations
    aggdat = _aggregate_distance_frames(args.input, "root")
    if aggdat is None:
        raise ValueError(
            "none of the simulated fasta files has more than two sequences")
    ndatasets = len(set(aggdat["data set"]))

    # experimental (absent, empty, or entirely skipped all mean "no data")
    aggdat_exp = None
    if args.experimental:
        aggdat_exp = _aggregate_distance_frames(args.experimental,
                                                args.root_idexp)
    if aggdat_exp is not None:
        ndatasets += len(set(aggdat_exp["data set"]))

    # fade the simulated curves when there are many of them
    alpha = min([0.9, 20 / ndatasets])

    plt.figure(figsize=(6, 3))
    plt.subplot(1, 2, 1)
    _plot_cumulative_histograms("distance to root sequence",
                                "observed sequences", aggdat, aggdat_exp,
                                alpha)
    plt.subplot(1, 2, 2)
    _plot_cumulative_histograms("nearest neighbor distance", "", aggdat,
                                aggdat_exp, alpha)

    plt.savefig(args.outbase + ".pdf")
aggdat = df else: aggdat = aggdat.append(df, ignore_index=True) sims = set(aggdat["simulation"]) nsims = len(sims) if args.experimental is not None: new_aln, counts = fasta_parse(args.experimental, root="GL", id_abundances=True)[:2] exp_dict = {seq.id: str(seq.seq) for seq in new_aln} root_id = [seq for seq in exp_dict if "gl" in seq][0] abundance, distance_from_root, degree = zip(*[( counts[seq], hamming_distance(exp_dict[seq], exp_dict[root_id]), sum( hamming_distance(exp_dict[seq], exp_dict[seq2]) == 1 for seq2 in exp_dict if seq2 is not seq and counts[seq2] != 0), ) for seq in exp_dict if counts[seq] != 0]) exp_stats = pd.DataFrame({ "genotype abundance": abundance, "Hamming distance to root genotype": distance_from_root, "Hamming neighbor genotypes": degree, }) # bw = .3 alpha = min([0.9, 20 / nsims]) bins = range( max( aggdat["Hamming distance to root genotype"].max(),
def simulate(args):
    """Simulation subprogram.

    Simulates a Galton–Watson process, with mutation probabilities
    according to a user defined motif model e.g. S5F.

    Reads its configuration from ``args`` (``seed``, ``mutability``,
    ``substitution``, ``lambda0``, ``lambda_``, ``sequence``, ``sequence2``,
    ``frame``, ``n``, ``N``, ``T``, ``verbose``, ``outbase``, ``plotAA``,
    ``idlabel``) and writes, all prefixed with ``args.outbase``: the
    simulated sequences as fasta (one file, or two for a sequence pair), a
    stats table, the rendered lineage tree, and the collapsed tree with its
    rendering and colormap. ``args.lambda0``, ``args.sequence`` and
    ``args.sequence2`` are modified in place.

    Raises:
        Exception: if a sequence pair is given with more than two lambda0.
        RuntimeError: if no acceptable tree was simulated in 1000 attempts.
    """
    random.seed(a=args.seed)
    mutation_model = mm.MutationModel(args.mutability, args.substitution)
    if args.lambda0 is None:
        args.lambda0 = [max([1, int(0.01 * len(args.sequence))])]
    args.sequence = args.sequence.upper()
    if args.sequence2 is not None:
        # treat the second sequence like the first one
        args.sequence2 = args.sequence2.upper()
        # Use the same mutation rate on both sequences
        if len(args.lambda0) == 1:
            args.lambda0 = [args.lambda0[0], args.lambda0[0]]
        elif len(args.lambda0) != 2:
            raise Exception("Only one or two lambda0 can be defined for a two "
                            "sequence simulation.")

        # Require both sequences to be in frame 1:
        if args.frame is not None and args.frame != 1:
            if args.verbose:
                print("Warning: When simulating with two sequences they are "
                      "truncated to be beginning at frame 1.")

            def _truncate_to_frame1(seq):
                # drop the leading out-of-frame bases and any trailing
                # partial codon
                start = args.frame - 1
                return seq[start:start + 3 * ((len(seq) - start) // 3)]

            args.sequence = _truncate_to_frame1(args.sequence)
            args.sequence2 = _truncate_to_frame1(args.sequence2)
        # Extract the bounds between sequence 1 and 2:
        seq_bounds = (
            (0, len(args.sequence)),
            (len(args.sequence), len(args.sequence) + len(args.sequence2)),
        )
        # Merge the two sequences to simplify future dealing with the pair:
        args.sequence += args.sequence2
    else:
        seq_bounds = None

    trials = 1000
    # this loop makes us resimulate if size too small, or backmutation
    for _ in range(trials):
        try:
            # NOTE(review): the keyword ``progeny`` must match the signature
            # of ``mm.MutationModel.simulate`` -- confirm against that class.
            tree = mutation_model.simulate(
                args.sequence,
                seq_bounds=seq_bounds,
                progeny=lambda seq: args.lambda_,
                lambda0=args.lambda0,
                n=args.n,
                N=args.N,
                T=args.T,
                frame=args.frame,
                verbose=args.verbose,
            )
            # this will fail if backmutations
            collapsed_tree = bp.CollapsedTree(tree=tree)
            tree.ladderize()
            uniques = sum(node.abundance > 0
                          for node in collapsed_tree.tree.traverse())
            if uniques < 2:
                raise RuntimeError(f"collapsed tree contains {uniques} "
                                   "sampled sequences")
        except RuntimeError as e:
            print(f"{e}, trying again")
        else:
            break
    else:
        # only reached when every attempt failed (a success on the last
        # attempt must not be reported as a failure)
        raise RuntimeError(f"{trials} attempts exceeded")

    # In the case of a sequence pair print them to separate files:
    if args.sequence2 is not None:
        (start1, end1), (start2, end2) = seq_bounds
        with open(args.outbase + ".simulation_seq1.fasta", "w") as fh1, \
                open(args.outbase + ".simulation_seq2.fasta", "w") as fh2:
            fh1.write(">root\n")
            fh1.write(args.sequence[start1:end1] + "\n")
            fh2.write(">root\n")
            fh2.write(args.sequence[start2:end2] + "\n")
            for leaf in tree.iter_leaves():
                if leaf.abundance != 0:
                    fh1.write(">" + leaf.name + "\n")
                    fh1.write(leaf.sequence[start1:end1] + "\n")
                    fh2.write(">" + leaf.name + "\n")
                    fh2.write(leaf.sequence[start2:end2] + "\n")
    else:
        with open(args.outbase + ".simulation.fasta", "w") as f:
            f.write(">root\n")
            f.write(args.sequence + "\n")
            for leaf in tree.iter_leaves():
                if leaf.abundance != 0:
                    f.write(">" + leaf.name + "\n")
                    f.write(leaf.sequence + "\n")

    # some observable simulation stats to write
    abundance, distance_from_root, degree = zip(*[(
        node.abundance,
        utils.hamming_distance(node.sequence, args.sequence),
        sum(
            utils.hamming_distance(node.sequence, node2.sequence) == 1
            for node2 in collapsed_tree.tree.traverse()
            if node2.abundance and node2 is not node),
    ) for node in collapsed_tree.tree.traverse() if node.abundance])
    stats = pd.DataFrame({
        "genotype abundance": abundance,
        "Hamming distance to root genotype": distance_from_root,
        "Hamming neighbor genotypes": degree,
    })
    stats.to_csv(args.outbase + ".simulation.stats.tsv",
                 sep="\t",
                 index=False)

    print(f"{sum(leaf.abundance for leaf in collapsed_tree.tree.traverse())}"
          " simulated observed sequences")

    # render the full lineage tree
    ts = ete3.TreeStyle()
    ts.rotation = 90
    ts.show_leaf_name = False
    ts.show_scale = False

    def _color_key(node):
        # nodes are colored by amino acid sequence when requested, otherwise
        # by nucleotide sequence
        return node.AAseq if args.plotAA else node.sequence

    colors = {}
    # build a new set so the module-level ete3.SVG_COLORS is left untouched
    palette = set(ete3.SVG_COLORS) - {"black", "white", "gray"}
    palette = itertools.cycle(list(palette))  # <-- circular iterator

    colors[tree.sequence] = "gray"

    for n in tree.traverse():
        nstyle = ete3.NodeStyle()
        nstyle["size"] = 10
        key = _color_key(n)
        if key not in colors:
            colors[key] = next(palette)
        nstyle["fgcolor"] = colors[key]
        n.set_style(nstyle)

    # this makes the rendered branch lengths correspond to time
    for node in tree.iter_descendants():
        node.dist = node.time - node.up.time
    tree.render(args.outbase + ".simulation.lineage_tree.svg", tree_style=ts)

    # render collapsed tree
    # create an id-wise colormap, looked up with the same key that was used
    # to fill ``colors`` (looking up node.sequence while coloring by AAseq
    # raised KeyError)
    # NOTE: node.name can be a collection of names rather than a string
    colormap = {
        node.name: colors[_color_key(node)]
        for node in collapsed_tree.tree.traverse()
    }
    collapsed_tree.write(args.outbase + ".simulation.collapsed_tree.p")
    collapsed_tree.render(
        args.outbase + ".simulation.collapsed_tree.svg",
        idlabel=args.idlabel,
        colormap=colormap,
        frame=args.frame,
    )
    # print colormap to file
    with open(args.outbase + ".simulation.collapsed_tree.colormap.tsv",
              "w") as f:
        for name, color in colormap.items():
            f.write((name if isinstance(name, str) else ",".join(name)) +
                    "\t" + color + "\n")
def __init__(self, tree: ete3.TreeNode = None, allow_repeats: bool = False):
    """Build a collapsed tree from an uncollapsed ete3 tree.

    When ``tree`` is given, a copy of it is collapsed in place: unobserved
    internal unifurcations are removed, zero-length edges are merged into
    their parent (summing abundances and combining names), and children are
    sorted by ``(partition, sequence)``. The per-node ``(abundance, number
    of children)`` pairs and their maxima are cached in ``_cm_list``,
    ``_c_max`` and ``_m_max``. When ``tree`` is ``None`` only ``self.tree``
    is set (to ``None``).

    Args:
        tree: tree whose nodes carry ``abundance``, ``sequence``, ``name``
            and ``dist``; it is copied, not modified.
        allow_repeats: if ``False``, observed nodes sharing a sequence after
            collapse raise an error; if ``True`` they are only reported.

    Raises:
        RuntimeError: if the set of observed genotype names changes during
            collapse, or if repeated observed sequences are found and
            ``allow_repeats`` is ``False``.
    """
    if tree is not None:
        self.tree = tree.copy()
        # remove unobserved internal unifurcations
        for node in self.tree.iter_descendants():
            # remember the parent before the node is removed
            parent = node.up
            if node.abundance == 0 and len(node.children) == 1:
                node.delete(prevent_nondicotomic=False)
                # the single child's branch length is recomputed against
                # the deleted node's former parent
                node.children[0].dist = utils.hamming_distance(
                    node.children[0].sequence, parent.sequence)

        # iterate over the tree below root and collapse edges of zero
        # length if the node is a leaf and its parent has nonzero
        # abundance we combine taxa names to a set to accommodate
        # bootstrap samples that result in repeated genotypes
        # names expected to survive the collapse: all leaves plus the root
        observed_genotypes = set((leaf.name for leaf in self.tree))
        observed_genotypes.add(self.tree.name)
        for node in self.tree.get_descendants(strategy="postorder"):
            if node.dist == 0:
                # merge this node into its parent
                node.up.abundance += node.abundance
                # a name is either a single string or a collection of names
                if isinstance(node.name, str):
                    node_set = set([node.name])
                else:
                    node_set = set(node.name)
                if isinstance(node.up.name, str):
                    node_up_set = set([node.up.name])
                else:
                    node_up_set = set(node.up.name)
                # ``<`` is a strict-subset test against the observed names
                if node_up_set < observed_genotypes:
                    if node_set < observed_genotypes:
                        # both are observed: the parent carries both names
                        node.up.name = tuple(node_set | node_up_set)
                        if len(node.up.name) == 1:
                            node.up.name = node.up.name[0]
                elif node_set < observed_genotypes:
                    # only the child is observed: the parent takes its name
                    node.up.name = tuple(node_set)
                    if len(node.up.name) == 1:
                        node.up.name = node.up.name[0]
                node.delete(prevent_nondicotomic=False)

        # sanity check: the names on observed nodes (and the root) must be
        # exactly the names collected before collapsing
        final_observed_genotypes = set()
        for node in self.tree.traverse():
            if node.abundance > 0 or node == self.tree:
                for name in ((node.name, )
                             if isinstance(node.name, str) else node.name):
                    final_observed_genotypes.add(name)
        if final_observed_genotypes != observed_genotypes:
            raise RuntimeError(
                "observed genotypes don't match after "
                f"collapse\n\tbefore: {observed_genotypes}"
                f"\n\tafter: {final_observed_genotypes}\n\t"
                "symmetric diff: "
                f"{observed_genotypes ^ final_observed_genotypes}")
        # collapsing must not change the total abundance
        assert sum(node.abundance for node in tree.traverse()) == sum(
            node.abundance for node in self.tree.traverse())

        # number of observed nodes in excess of distinct observed sequences
        rep_seq = sum(
            node.abundance > 0 for node in self.tree.traverse()) - len(
                set([
                    node.sequence for node in self.tree.traverse()
                    if node.abundance > 0
                ]))
        if not allow_repeats and rep_seq:
            raise RuntimeError(
                "Repeated observed sequences in collapsed "
                f"tree. {rep_seq} sequences were found repeated.")
        elif allow_repeats and rep_seq:
            # NOTE(review): same expression as above; this recomputation
            # looks redundant
            rep_seq = sum(node.abundance > 0
                          for node in self.tree.traverse()) - len(
                              set([
                                  node.sequence
                                  for node in self.tree.traverse()
                                  if node.abundance > 0
                              ]))
            print("Repeated observed sequences in collapsed tree. "
                  f"{rep_seq} sequences were found repeated.")
        # a custom ladderize accounting for abundance and sequence to break
        # ties in abundance
        for node in self.tree.traverse(strategy="postorder"):
            # add a partition feature and compute it recursively up tree
            # (postorder guarantees the children's partitions exist)
            node.add_feature(
                "partition",
                node.abundance + sum(node2.partition
                                     for node2 in node.children),
            )
            # sort children of this node based on partition and sequence
            node.children.sort(
                key=lambda node: (node.partition, node.sequence))

        # create list of (c, m) for each node
        self._cm_list = [(node.abundance, len(node.children))
                         for node in self.tree.traverse()]
        # store max c and m
        self._c_max = max(node.abundance for node in self.tree.traverse())
        self._m_max = max(
            len(node.children) for node in self.tree.traverse())
    else:
        # no tree given: tree is None here
        self.tree = tree
def simulate(
    self,
    sequence: str,
    seq_bounds: Tuple[Tuple[int, int], Tuple[int, int]] = None,
    fitness_function: Callable = lambda seq: 0.9,
    lambda0: List[np.float64] = [1],
    frame: int = None,
    N_init: int = 1,
    N: int = None,
    T: int = None,
    n: int = None,
    verbose: bool = False,
) -> TreeNode:
    r"""Simulate a neutral binary branching process with the mutation
    model, returning a :class:`ete3.Treenode` object.

    Every generation each unterminated leaf draws a Poisson number of
    children with mean ``fitness_function(leaf.sequence)``; each child
    carries a mutated copy of its parent's sequence. Simulation stops when
    no unterminated leaf is left, when ``N`` unterminated leaves are reached,
    or when the last time in ``T`` is reached. Sampled leaves get
    ``abundance`` 1, lineages without sampled nodes are pruned, and nodes are
    named ``simcell_<i>``.

    Args:
        sequence: root nucleotide sequence
        seq_bounds: ranges for two subsequences used as two parallel genes
        fitness_function: mean number offspring as a function of sequence
        lambda0: baseline mutation rate(s); the default list is never
            modified here
        frame: coding frame of starting position(s)
        N_init: initial naive abundance
        N: maximum population size
        T: maximum generation time, either a single integer or a sequence of
            sampling times whose maximum ends the simulation
        n: sample size
        verbose: print more messages

    Raises:
        ValueError: if both or neither of ``N`` and ``T`` are given, or if
            ``n`` exceeds ``N``.
        RuntimeError: if the process ends with fewer leaves than required.
    """
    # Checking the validity of the input parameters:
    if N is not None and T is not None:
        raise ValueError(
            "Only one of N and T can be used. One must be None.")
    elif N is None and T is None:
        raise ValueError("Either N or T must be specified.")
    if N is not None and n is not None and n > N:
        raise ValueError("n ({}) must not larger than N ({})".format(n, N))

    # T is used as a collection of times below (max, len, sorted); accept
    # the single integer promised by the annotation as well.
    if T is not None and isinstance(T, (int, np.integer)):
        T = [T]
    t_max = max(T) if T is not None else None

    # Planting the tree:
    tree = TreeNode()
    tree.dist = 0
    tree.add_feature("sequence", sequence)
    tree.add_feature("terminated", False)
    tree.add_feature("abundance", 0)
    tree.add_feature("time", 0)
    # add fitness attribute, interpreted as mean of offspring distribution
    tree.add_feature("fitness", fitness_function(tree.sequence))

    if N_init > 1:
        # several founder cells: hang identical copies below the root
        for _ in range(N_init):
            child = TreeNode()
            child.dist = 0
            child.add_feature("sequence", sequence)
            child.add_feature("abundance", 0)
            child.add_feature("terminated", False)
            child.add_feature("time", 0)
            # add fitness attribute, interpreted as mean of offspring distribution
            child.add_feature("fitness", fitness_function(child.sequence))
            tree.add_child(child)

    t = 0  # <-- time
    leaves_unterminated = N_init
    while (leaves_unterminated > 0
           and (leaves_unterminated < N if N is not None else True)
           and (t < t_max if T is not None else True)):
        if verbose:
            print("At time:", t)
        t += 1
        list_of_leaves = list(tree.iter_leaves())
        random.shuffle(list_of_leaves)
        for leaf in list_of_leaves:
            # add fitness attribute, interpreted as mean of offspring distribution
            leaf.add_feature("fitness", fitness_function(leaf.sequence))
            if not leaf.terminated:
                n_children = poisson(leaf.fitness).rvs()
                # <-- this kills the parent if we drew a zero
                leaves_unterminated += (n_children - 1)
                if not n_children:
                    leaf.terminated = True
                for _ in range(n_children):
                    # If sequence pair mutate them separately with their own mutation rate:
                    if seq_bounds is not None:
                        mutated_sequence1 = self.mutate(
                            leaf.sequence[seq_bounds[0][0]:seq_bounds[0][1]],
                            lambda0=lambda0[0],
                            frame=frame,
                        )
                        mutated_sequence2 = self.mutate(
                            leaf.sequence[seq_bounds[1][0]:seq_bounds[1][1]],
                            lambda0=lambda0[1],
                            frame=frame,
                        )
                        mutated_sequence = mutated_sequence1 + mutated_sequence2
                    else:
                        mutated_sequence = self.mutate(leaf.sequence,
                                                       lambda0=lambda0[0],
                                                       frame=frame)
                    child = TreeNode()
                    child.dist = utils.hamming_distance(
                        mutated_sequence, leaf.sequence)
                    child.add_feature("sequence", mutated_sequence)
                    child.add_feature("abundance", 0)
                    child.add_feature("terminated", False)
                    child.add_feature("time", t)
                    leaf.add_child(child)

    if N is not None and leaves_unterminated < N:
        raise RuntimeError(
            "tree terminated with {} leaves, {} desired".format(
                leaves_unterminated, N))

    # each leaf in final generation gets an observed abundance of 1, unless downsampled
    if T is not None and len(T) > 1:
        # Iterate the intermediate time steps:
        for Ti in sorted(T)[:-1]:
            # Only sample those that have been 'sampled' at intermediate sampling times:
            # NOTE(review): no ``sampled`` feature is set anywhere in this
            # function, so this lookup presumably fails unless ``mutate`` or
            # the caller adds it -- confirm.
            final_leaves = [
                leaf for leaf in tree.iter_descendants()
                if leaf.time == Ti and leaf.sampled
            ]
            if len(final_leaves) < n:
                raise RuntimeError(
                    "tree terminated with {} leaves, less than what desired after downsampling {}"
                    .format(leaves_unterminated, n))
            # No need to down-sample, this was already done in the simulation loop
            for leaf in final_leaves:
                leaf.abundance = 1
    # Do the normal sampling of the last time step:
    final_leaves = [leaf for leaf in tree.iter_leaves() if leaf.time == t]
    # by default, downsample to the target simulation size
    # NOTE(review): with T given, the third branch is taken before the
    # "fewer leaves than n" error can be reached -- confirm that sampling
    # every leaf is the intended fallback.
    if n is not None and len(final_leaves) >= n:
        for leaf in random.sample(final_leaves, n):
            leaf.abundance = 1
    elif n is None and N is not None:
        for leaf in random.sample(final_leaves, N):
            leaf.abundance = 1
    elif N is None and T is not None:
        for leaf in final_leaves:
            leaf.abundance = 1
    elif n is not None and len(final_leaves) < n:
        raise RuntimeError(
            "tree terminated with {} leaves, less than what desired after downsampling {}"
            .format(leaves_unterminated, n))
    else:
        raise RuntimeError("Unknown option.")

    # prune away lineages that are unobserved
    for node in tree.iter_descendants():
        if sum(node2.abundance for node2 in node.traverse()) == 0:
            node.detach()

    # assign unique names to each node
    for i, node in enumerate(tree.traverse(), 1):
        node.name = "simcell_{}".format(i)

    # return the uncollapsed tree
    return tree