def isotype_tree(
    tree: ete3.TreeNode,
    newidmap: Dict[str, Dict[str, str]],
    isotype_names: Sequence[str],
    weight_matrix: Optional[Sequence[Sequence[float]]] = None,
) -> ete3.TreeNode:
    """Return a copy of ``tree`` annotated with isotypes.

    Isotypes are assigned so that isotype switching is minimized and the
    switching order is obeyed. The work happens in two stages:

    * Observed isotypes are attached to each observed node of the collapsed
      tree produced by gctree inference. When cells sharing a sequence were
      observed with different isotypes, the collapsed node is 'exploded'
      into new nodes carrying the appropriate isotypes and abundances. Each
      unique sequence ID generated by gctree is prepended to its observed
      isotype, and a new `isotyped.idmap` mapping these new sequence IDs to
      the original sequence IDs is written in the output directory.
    * Isotypes of unobserved ancestral genotypes are resolved so as to
      minimize isotype switching while obeying switching order. If the
      observed isotypes of an observed internal node and its children
      violate switching order, the observed internal node is replaced by an
      unobserved node with the same sequence, and the observed node becomes
      a child leaf. This always allows switching order conflicts to be
      resolved, and should usually increase the isotype transitions
      required in the resulting tree.

    After annotation, every node name has its isotype appended (separated
    by a space) and every non-root branch length is recomputed as the
    Hamming distance between parent and child sequences.

    Args:
        tree: ete3 Tree. It is copied; the argument itself is not modified
            by this function.
        newidmap: mapping of sequence IDs to isotypes, such as that output
            by :meth:`utils.explode_idmap`.
        isotype_names: list or other sequence of isotype names observed, in
            correct switching order.
        weight_matrix: optional matrix forwarded unchanged to
            ``_add_observed_isotypes``. NOTE(review): its exact meaning is
            not visible from this function -- confirm against that helper.

    Returns:
        A new ete3 Tree whose nodes have isotype annotations in the
        attribute ``isotype``. Node names in this tree also contain isotype
        names.
    """
    annotated = tree.copy()

    # Annotation pipeline: observed isotypes first, then resolution of the
    # ambiguous ones, then merging of nodes that became indistinguishable.
    _add_observed_isotypes(
        annotated, newidmap, isotype_names, weight_matrix=weight_matrix
    )
    _disambiguate_isotype(annotated)
    _collapse_tree_by_sequence_and_isotype(annotated)

    # Make the isotype visible in every node label, root included.
    for labelled in annotated.traverse():
        labelled.name = " ".join((str(labelled.name), str(labelled.isotype)))

    # Branch lengths are refreshed for every node below the root.
    for child in annotated.iter_descendants():
        child.dist = hamming_distance(child.up.sequence, child.sequence)

    return annotated
def __init__(self, tree: ete3.TreeNode = None, allow_repeats: bool = False):
    """Build a collapsed tree from ``tree``.

    When ``tree`` is ``None`` only ``self.tree`` is set (to ``None``) and
    nothing else is initialized. Otherwise a copy of ``tree`` is collapsed:
    unobserved unifurcations are removed, zero-length edges are merged into
    their parent, children are sorted, and per-node ``(abundance,
    number of children)`` summaries are cached.

    Args:
        tree: ete3 tree whose nodes carry ``abundance``, ``sequence``,
            ``name`` and ``dist``; it is copied, not modified.
        allow_repeats: when false, repeated observed sequences in the
            collapsed tree raise; when true they are only reported with
            ``print``.

    Raises:
        RuntimeError: if the set of observed genotype names changes during
            the collapse, or if repeated observed sequences are found and
            ``allow_repeats`` is false.
    """
    if tree is None:
        self.tree = tree
        return

    def _as_name_set(name):
        """Return ``name`` as a set: a str is wrapped, anything else unpacked."""
        return set([name]) if isinstance(name, str) else set(name)

    self.tree = tree.copy()

    # remove unobserved internal unifurcations
    for node in self.tree.iter_descendants():
        if node.abundance != 0 or len(node.children) != 1:
            continue
        # the parent must be captured before the node is spliced out
        grandparent = node.up
        node.delete(prevent_nondicotomic=False)
        only_child = node.children[0]
        only_child.dist = utils.hamming_distance(
            only_child.sequence, grandparent.sequence
        )

    # Collapse zero-length edges below the root. Taxa names are combined
    # into tuples to accommodate bootstrap samples that result in repeated
    # genotypes.
    observed_genotypes = set(leaf.name for leaf in self.tree)
    observed_genotypes.add(self.tree.name)
    for node in self.tree.get_descendants(strategy="postorder"):
        if node.dist != 0:
            continue
        parent = node.up
        parent.abundance += node.abundance
        child_names = _as_name_set(node.name)
        parent_names = _as_name_set(parent.name)
        child_is_observed = child_names < observed_genotypes
        if parent_names < observed_genotypes:
            # both observed: merge; only the parent observed: keep its name
            merged = (
                tuple(child_names | parent_names) if child_is_observed else None
            )
        elif child_is_observed:
            merged = tuple(child_names)
        else:
            merged = None
        if merged is not None:
            # a single name is stored bare rather than as a 1-tuple
            parent.name = merged[0] if len(merged) == 1 else merged
        node.delete(prevent_nondicotomic=False)

    # Sanity check: collapsing must not gain or lose observed genotypes.
    final_observed_genotypes = {
        name
        for node in self.tree.traverse()
        if node.abundance > 0 or node == self.tree
        for name in ((node.name,) if isinstance(node.name, str) else node.name)
    }
    if final_observed_genotypes != observed_genotypes:
        raise RuntimeError(
            "observed genotypes don't match after "
            f"collapse\n\tbefore: {observed_genotypes}"
            f"\n\tafter: {final_observed_genotypes}\n\t"
            "symmetric diff: "
            f"{observed_genotypes ^ final_observed_genotypes}"
        )
    assert sum(node.abundance for node in tree.traverse()) == sum(
        node.abundance for node in self.tree.traverse()
    )

    # Number of observed nodes in excess of the number of distinct
    # observed sequences.
    observed_sequences = [
        node.sequence for node in self.tree.traverse() if node.abundance > 0
    ]
    rep_seq = len(observed_sequences) - len(set(observed_sequences))
    if rep_seq:
        if not allow_repeats:
            raise RuntimeError(
                "Repeated observed sequences in collapsed "
                f"tree. {rep_seq} sequences were found repeated."
            )
        print(
            "Repeated observed sequences in collapsed tree. "
            f"{rep_seq} sequences were found repeated."
        )

    # a custom ladderize accounting for abundance and sequence to break
    # ties in abundance
    for node in self.tree.traverse(strategy="postorder"):
        # partition = abundance of this node plus that of its whole clade;
        # postorder guarantees the children already carry theirs
        clade_total = node.abundance + sum(kid.partition for kid in node.children)
        node.add_feature("partition", clade_total)
        # sort children of this node based on partition and sequence
        node.children.sort(key=lambda kid: (kid.partition, kid.sequence))

    # (c, m) = (abundance, number of children) for each node, plus maxima
    self._cm_list = [
        (node.abundance, len(node.children)) for node in self.tree.traverse()
    ]
    self._c_max = max(c for c, _ in self._cm_list)
    self._m_max = max(m for _, m in self._cm_list)
def treeprint(tree: ete3.TreeNode):
    """Serialize ``tree`` with each node named by its sequence.

    The input is copied first, so the caller's node names are untouched.

    Args:
        tree: ete3 tree whose nodes carry a ``sequence`` attribute.

    Returns:
        Whatever ``TreeNode.write(format=8)`` returns for the renamed copy.
        NOTE(review): presumably a newick string containing names only --
        confirm against the ete3 format table.
    """
    relabelled = tree.copy()
    for current in relabelled.traverse():
        current.name = current.sequence
    newick = relabelled.write(format=8)
    return newick
class CollapsedTree(LeavesAndClades):
    '''
    A collapsed tree, where we recurse into the mutant clades.

    Each node carries a ``frequency`` (number of observed clones) and its
    children are the mutant clades, e.g.

          (4)
         / | \\
       (3)(1)(2)
           | \\
          (2) (1)
    '''
    def __init__(self,
                 params=None,
                 tree=None,
                 frame=None,
                 collapse_syn=False,
                 allow_repeats=False):
        '''
        For initialization, either params or tree (or both) must be provided

        params: offspring distribution parameters
        tree: ete tree with frequency node feature. If uncollapsed, it will
              be collapsed
        frame: translation frame, with default None, no translation attempted
        collapse_syn: if True, branch lengths of ``tree`` are first replaced
                      (in place, on the caller's tree) by amino acid Hamming
                      distances, so that synonymous edges have length zero
                      and get collapsed; needs both ``tree`` and ``frame``
        allow_repeats: if False, repeated observed sequences in the collapsed
                       tree raise RuntimeError; if True they are only printed

        raises RuntimeError for an invalid frame, for a mismatch in observed
        genotypes after collapsing, or for disallowed repeated sequences
        '''
        LeavesAndClades.__init__(self, params=params)
        if frame is not None and frame not in (1, 2, 3):
            raise RuntimeError('frame must be 1, 2, 3, or None')
        self.frame = frame

        if collapse_syn is True:
            tree.dist = 0  # no branch above root
            for node in tree.iter_descendants():
                # both translations are trimmed using the child's length
                n_sites = len(node.sequence)
                aa = self._translate_in_frame(node.sequence, frame, n_sites)
                aa_parent = self._translate_in_frame(node.up.sequence, frame,
                                                     n_sites)
                node.dist = hamming_distance(aa, aa_parent)

        if tree is not None:
            self.tree = tree.copy()
            # remove unobserved internal unifurcations
            for node in self.tree.iter_descendants():
                parent = node.up
                if node.frequency == 0 and len(node.children) == 1:
                    node.delete(prevent_nondicotomic=False)
                    node.children[0].dist = hamming_distance(
                        node.children[0].sequence, parent.sequence)

            # iterate over the tree below root and collapse edges of zero
            # length. Taxa names are combined into tuples; this accommodates
            # bootstrap samples that result in repeated genotypes
            observed_genotypes = set((leaf.name for leaf in self.tree))
            observed_genotypes.add(self.tree.name)
            for node in self.tree.get_descendants(strategy='postorder'):
                if node.dist == 0:
                    node.up.frequency += node.frequency
                    node_set = set([node.name]) if isinstance(
                        node.name, str) else set(node.name)
                    node_up_set = set([node.up.name]) if isinstance(
                        node.up.name, str) else set(node.up.name)
                    if node_up_set < observed_genotypes:
                        if node_set < observed_genotypes:
                            node.up.name = tuple(node_set | node_up_set)
                            if len(node.up.name) == 1:
                                node.up.name = node.up.name[0]
                    elif node_set < observed_genotypes:
                        node.up.name = tuple(node_set)
                        if len(node.up.name) == 1:
                            node.up.name = node.up.name[0]
                    node.delete(prevent_nondicotomic=False)

            # sanity check: collapsing must not gain or lose genotype names
            final_observed_genotypes = set([
                name for node in self.tree.traverse()
                if node.frequency > 0 or node == self.tree
                for name in ((node.name, ) if isinstance(node.name, str) else
                             node.name)
            ])
            if final_observed_genotypes != observed_genotypes:
                raise RuntimeError(
                    'observed genotypes don\'t match after collapse\n\tbefore: {}\n\tafter: {}\n\tsymmetric diff: {}'
                    .format(observed_genotypes, final_observed_genotypes,
                            observed_genotypes ^ final_observed_genotypes))
            assert sum(node.frequency for node in tree.traverse()) == sum(
                node.frequency for node in self.tree.traverse())

            # observed nodes in excess of distinct observed sequences
            rep_seq = sum(
                node.frequency > 0 for node in self.tree.traverse()) - len(
                    set([
                        node.sequence for node in self.tree.traverse()
                        if node.frequency > 0
                    ]))
            if not allow_repeats and rep_seq:
                raise RuntimeError(
                    'Repeated observed sequences in collapsed tree. {} sequences were found repeated.'
                    .format(rep_seq))
            elif allow_repeats and rep_seq:
                print(
                    'Repeated observed sequences in collapsed tree. {} sequences were found repeated.'
                    .format(rep_seq))

            # a custom ladderize accounting for abundance and sequence to
            # break ties in abundance
            for node in self.tree.traverse(strategy='postorder'):
                # add a partition feature and compute it recursively up the
                # tree
                node.add_feature(
                    'partition', node.frequency +
                    sum(node2.partition for node2 in node.children))
                # sort children of this node based on partition and sequence
                node.children.sort(
                    key=lambda node: (node.partition, node.sequence))
        else:
            self.tree = tree

    @staticmethod
    def _translate_in_frame(sequence, frame, length):
        '''
        translate ``sequence`` starting at 1-based ``frame``, keeping only
        the whole codons that fit within the first ``length`` positions
        '''
        start = frame - 1
        stop = start + 3 * ((length - start) // 3)
        return Seq(sequence[start:stop], generic_dna).translate()

    def l(self, params, sign=1):
        '''
        log likelihood of params, conditioned on collapsed tree, and its
        gradient wrt params

        optional parameter sign must be 1 or -1, with the latter useful for
        MLE by minimization

        returns the pair (sign * log likelihood, sign * gradient)
        raises ValueError if there is no tree or sign is invalid
        '''
        if self.tree is None:
            raise ValueError('tree data must be defined to compute likelihood')
        if sign not in (-1, 1):
            raise ValueError('sign must be 1 or -1')
        leaves_and_clades_list = [
            LeavesAndClades(c=node.frequency, m=len(node.children))
            for node in self.tree.traverse()
        ]
        # the first traversed node is the root (the naive)
        root_lc = leaves_and_clades_list[0]
        if root_lc.c == 0 and root_lc.m == 1 and root_lc.f(params)[0] == 0:
            # if unifurcation not possible under current model, add a
            # pseudocount for the naive
            root_lc.c = 1
        # extract vector of function values and gradient components
        f_data = [
            leaves_and_clades.f(params)
            for leaves_and_clades in leaves_and_clades_list
        ]
        fs = scipy.array([[x[0]] for x in f_data])
        logf = scipy.log(fs).sum()
        grad_fs = scipy.array([x[1] for x in f_data])
        grad_logf = (grad_fs / fs).sum(axis=0)
        return sign * logf, sign * grad_logf

    def mle(self, **kwargs):
        '''
        Maximum likelihood estimate for params given tree
        updates params if not None
        returns optimization result

        kwargs are forwarded to ``l``; ``sign`` is always overridden to -1 so
        that the likelihood is maximized by minimization
        '''
        # random initialization
        x_0 = (random.random(), random.random())
        bounds = ((.01, .99), (.001, .999))
        kwargs['sign'] = -1
        # compare the analytic gradient to finite differences at a fixed point
        grad_check = check_grad(lambda x: self.l(x, **kwargs)[0],
                                lambda x: self.l(x, **kwargs)[1], (.4, .5))
        if grad_check > 1e-3:
            warnings.warn(
                'gradient mismatches finite difference approximation by {}'.
                format(grad_check), RuntimeWarning)
        result = minimize(lambda x: self.l(x, **kwargs),
                          x0=x_0,
                          jac=True,
                          method='L-BFGS-B',
                          options={'ftol': 1e-10},
                          bounds=bounds)
        # update params if None and optimization successful
        if not result.success:
            warnings.warn('optimization not sucessful, ' + result.message,
                          RuntimeWarning)
        elif self.params is None:
            self.params = result.x.tolist()
        return result

    def simulate(self):
        '''
        simulate a collapsed tree given params
        replaces existing tree data member with simulation result, and
        returns self

        raises ValueError if params is None
        '''
        if self.params is None:
            raise ValueError('params must be defined for simulation')

        # initiate by running a LeavesAndClades simulation to get the number
        # of clones and mutants in the root node of the collapsed tree
        LeavesAndClades.simulate(self)
        self.tree = TreeNode()
        self.tree.add_feature('frequency', self.c)
        if self.m == 0:
            return self
        for _ in range(self.m):
            # each mutant clade is itself a simulated collapsed tree
            child = CollapsedTree(params=self.params,
                                  frame=self.frame).simulate().tree
            child.dist = 1
            self.tree.add_child(child)

        return self

    def __str__(self):
        '''return a string representation for printing'''
        return 'params = ' + str(self.params) + '\ntree:\n' + str(self.tree)

    def render(self,
               outfile,
               idlabel=False,
               colormap=None,
               show_support=False,
               chain_split=None):
        '''
        render to image file, filetype inferred from suffix, svg for color
        images

        outfile: path of the image to write
        idlabel: if True, label nodes with their names and also write the
                 node sequences to a fasta file next to outfile
        colormap: optional mapping from node name to either a color string
                  or a dict of color -> count (drawn as a pie chart)
        show_support: passed to the tree style as show_branch_support
        chain_split: optional sequence index; when given, edges are colored
                     by which side(s) of the split mutated. Not supported
                     together with a frame (raises NotImplementedError)
        '''
        def my_layout(node):
            circle_color = 'lightgray' if colormap is None or node.name not in colormap else colormap[
                node.name]
            text_color = 'black'
            if isinstance(circle_color, str):
                C = CircleFace(radius=max(3, 10 * scipy.sqrt(node.frequency)),
                               color=circle_color,
                               label={
                                   'text': str(node.frequency),
                                   'color': text_color
                               } if node.frequency > 0 else None)
                C.rotation = -90
                C.hz_align = 1
                faces.add_face_to_node(C, node, 0)
            else:
                P = PieChartFace(
                    [100 * x / node.frequency for x in circle_color.values()],
                    2 * 10 * scipy.sqrt(node.frequency),
                    2 * 10 * scipy.sqrt(node.frequency),
                    colors=[(color if color != 'None' else 'lightgray')
                            for color in list(circle_color.keys())],
                    line_color=None)
                T = TextFace(' '.join(
                    [str(x) for x in list(circle_color.values())]),
                             tight_text=True)
                T.hz_align = 1
                T.rotation = -90
                faces.add_face_to_node(P, node, 0, position='branch-right')
                faces.add_face_to_node(T, node, 1, position='branch-right')
            if idlabel:
                T = TextFace(node.name, tight_text=True, fsize=6)
                T.rotation = -90
                T.hz_align = 1
                faces.add_face_to_node(
                    T,
                    node,
                    1 if isinstance(circle_color, str) else 2,
                    position='branch-right')

        for node in self.tree.traverse():
            nstyle = NodeStyle()
            nstyle['size'] = 0
            if node.up is not None:
                # edge styling is only applied when the sequence uses exactly
                # the four bases A, C, G and T
                if set(node.sequence.upper()) == set('ACGT'):
                    if chain_split is not None:
                        if self.frame is not None:
                            raise NotImplementedError(
                                'frame not implemented with chain_split')
                        leftseq_mutated = hamming_distance(
                            node.sequence[:chain_split],
                            node.up.sequence[:chain_split]) > 0
                        rightseq_mutated = hamming_distance(
                            node.sequence[chain_split:],
                            node.up.sequence[chain_split:]) > 0
                        if leftseq_mutated and rightseq_mutated:
                            nstyle['hz_line_color'] = 'purple'
                            nstyle['hz_line_width'] = 3
                        elif leftseq_mutated:
                            nstyle['hz_line_color'] = 'red'
                            nstyle['hz_line_width'] = 2
                        elif rightseq_mutated:
                            nstyle['hz_line_color'] = 'blue'
                            nstyle['hz_line_width'] = 2
                    if self.frame is not None:
                        n_sites = len(node.sequence)
                        aa = self._translate_in_frame(node.sequence,
                                                      self.frame, n_sites)
                        aa_parent = self._translate_in_frame(
                            node.up.sequence, self.frame, n_sites)
                        nonsyn = hamming_distance(aa, aa_parent)
                        # a stop codon in the translation is flagged in red
                        if '*' in aa:
                            nstyle['bgcolor'] = 'red'
                        if nonsyn > 0:
                            nstyle['hz_line_color'] = 'black'
                            nstyle['hz_line_width'] = nonsyn
                        else:
                            nstyle['hz_line_type'] = 1
            node.set_style(nstyle)

        ts = TreeStyle()
        ts.show_leaf_name = False
        ts.rotation = 90
        ts.draw_aligned_faces_as_table = False
        ts.allow_face_overlap = True
        ts.layout_fn = my_layout
        ts.show_scale = False
        ts.show_branch_support = show_support
        self.tree.render(outfile, tree_style=ts)
        # if we labelled seqs, let's also write the alignment out so we have
        # the sequences (including of internal nodes)
        if idlabel:
            aln = MultipleSeqAlignment([])
            for node in self.tree.traverse():
                aln.append(
                    SeqRecord(Seq(str(node.sequence), generic_dna),
                              id=str(node.name),
                              description='abundance={}'.format(
                                  node.frequency)))
            # context manager so the fasta handle is flushed and closed
            with open(os.path.splitext(outfile)[0] + '.fasta',
                      'w') as fasta_file:
                AlignIO.write(aln, fasta_file, 'fasta')

    def write(self, file_name):
        '''serialize tree to file'''
        with open(file_name, 'wb') as f:
            pickle.dump(self, f)

    def compare(self, tree2, method='identity'):
        '''
        compare this tree to the other tree

        tree2: another CollapsedTree
        method: 'identity' returns True/False, 'MRCA' returns the summed
                Hamming distance between MRCA sequences normalized by the
                summed sequence lengths, 'RF' returns the first element of
                ete's robinson_foulds result

        raises ValueError for any other method
        '''
        if method == 'identity':
            # we compare lists of seq, parent, abundance
            # return true if these lists are identical, else false
            list1 = sorted((node.sequence, node.frequency,
                            node.up.sequence if node.up is not None else None)
                           for node in self.tree.traverse())
            list2 = sorted((node.sequence, node.frequency,
                            node.up.sequence if node.up is not None else None)
                           for node in tree2.tree.traverse())
            return list1 == list2
        elif method == 'MRCA':
            # matrix of hamming distance of common ancestors of taxa
            # takes a true and inferred tree as CollapsedTree objects
            taxa = [
                node.sequence for node in self.tree.traverse()
                if node.frequency
            ]
            n_taxa = len(taxa)
            d = scipy.zeros(shape=(n_taxa, n_taxa))
            sum_sites = scipy.zeros(shape=(n_taxa, n_taxa))
            for i in range(n_taxa):
                # builtin next() instead of the Python 2-only .next() method
                nodei_true = next(
                    self.tree.iter_search_nodes(sequence=taxa[i]))
                nodei = next(tree2.tree.iter_search_nodes(sequence=taxa[i]))
                for j in range(i + 1, n_taxa):
                    nodej_true = next(
                        self.tree.iter_search_nodes(sequence=taxa[j]))
                    nodej = next(
                        tree2.tree.iter_search_nodes(sequence=taxa[j]))
                    MRCA_true = self.tree.get_common_ancestor(
                        (nodei_true, nodej_true)).sequence
                    MRCA = tree2.tree.get_common_ancestor(
                        (nodei, nodej)).sequence
                    d[i, j] = hamming_distance(MRCA_true, MRCA)
                    sum_sites[i, j] = len(MRCA_true)
            return d.sum() / sum_sites.sum()
        elif method == 'RF':
            tree1_copy = self.tree.copy(method='deepcopy')
            tree2_copy = tree2.tree.copy(method='deepcopy')
            # hang an extra child carrying the sequence under every observed
            # node, so observed internal nodes also appear as leaves
            for treex in (tree1_copy, tree2_copy):
                for node in list(treex.traverse()):
                    if node.frequency > 0:
                        child = TreeNode()
                        child.add_feature('sequence', node.sequence)
                        node.add_child(child)
            try:
                return tree1_copy.robinson_foulds(tree2_copy,
                                                  attr_t1='sequence',
                                                  attr_t2='sequence',
                                                  unrooted_trees=True)[0]
            except Exception:
                # retry allowing duplicated attributes; Exception (not a bare
                # except) so KeyboardInterrupt/SystemExit still propagate
                return tree1_copy.robinson_foulds(tree2_copy,
                                                  attr_t1='sequence',
                                                  attr_t2='sequence',
                                                  unrooted_trees=True,
                                                  allow_dup=True)[0]
        else:
            raise ValueError('invalid distance method: ' + method)

    def get_split(self, node):
        '''
        return the bipartition resulting from clipping this node's edge above

        the result is a 2-tuple of sets of taxon names; raises ValueError if
        node is not in this tree or is the root
        '''
        if node.get_tree_root() != self.tree:
            raise ValueError('node not found')
        if node == self.tree:
            raise ValueError('this node is the root (no split above)')
        parent = node.up
        # taxa in the clade below the clipped edge
        taxa1 = []
        for node2 in node.traverse():
            if node2.frequency > 0 or node2 == self.tree:
                if isinstance(node2.name, str):
                    taxa1.append(node2.name)
                else:
                    taxa1.extend(node2.name)
        taxa1 = set(taxa1)
        # temporarily detach the clade to collect the remaining taxa
        node.detach()
        taxa2 = []
        for node2 in self.tree.traverse():
            if node2.frequency > 0 or node2 == self.tree:
                if isinstance(node2.name, str):
                    taxa2.append(node2.name)
                else:
                    taxa2.extend(node2.name)
        taxa2 = set(taxa2)
        # restore the clade under its original parent
        parent.add_child(node)
        assert taxa1.isdisjoint(taxa2)
        assert taxa1.union(taxa2) == set(
            (name for node in self.tree.traverse()
             if node.frequency > 0 or node == self.tree
             for name in ((node.name, ) if isinstance(node.name, str) else
                          node.name)))
        return tuple(sorted([taxa1, taxa2]))

    @staticmethod
    def split_compatibility(split1, split2):
        '''
        return True if some part of split1 is disjoint from some part of
        split2, else False

        raises ValueError if the two splits do not cover the same taxa
        '''
        diff = split1[0].union(split1[1]) ^ split2[0].union(split2[1])
        if diff:
            raise ValueError(
                'splits do not cover the same taxa\n\ttaxa not in both: {}'.
                format(diff))
        for partition1 in split1:
            for partition2 in split2:
                if partition1.isdisjoint(partition2):
                    return True
        return False

    def support(self,
                bootstrap_trees_list,
                weights=None,
                compatibility=False):
        '''
        compute support from a list of bootstrap GCtrees
        weights (optional) is needed for weighting parsimony degenerate trees
        compatibility mode counts trees that don't disconfirm the split

        sets the ``support`` attribute on every node below the root and
        returns self
        '''
        for node in self.tree.get_descendants():
            split = self.get_split(node)
            support = 0
            compatibility_ = 0
            for i, tree in enumerate(bootstrap_trees_list):
                compatible = True
                supported = False
                for boot_node in tree.tree.get_descendants():
                    boot_split = tree.get_split(boot_node)
                    if compatibility and compatible and not self.split_compatibility(
                            split, boot_split):
                        compatible = False
                    if not compatibility and not supported and boot_split == split:
                        supported = True
                # each bootstrap tree counts once, scaled by its weight
                if supported:
                    support += weights[i] if weights is not None else 1
                if compatible:
                    compatibility_ += weights[i] if weights is not None else 1
            node.support = compatibility_ if compatibility else support

        return self