Exemplo n.º 1
0
def isotype_tree(
    tree: ete3.TreeNode,
    newidmap: Dict[str, Dict[str, str]],
    isotype_names: Sequence[str],
    weight_matrix: Optional[Sequence[Sequence[float]]] = None,
) -> ete3.TreeNode:
    """Return a copy of ``tree`` annotated with isotypes, minimizing isotype
    switching while obeying switching order.

    The work is done in three in-place passes over the copy:

    1. Observed isotypes are attached to each observed node of the
       collapsed tree output by gctree inference. When cells sharing a
       sequence were observed with different isotypes, the collapsed node
       has to be 'exploded' into new nodes carrying the appropriate
       isotypes and abundances. Each unique sequence ID generated by
       gctree is prepended to its observed isotype, and a new
       `isotyped.idmap` mapping these new sequence IDs to original
       sequence IDs is written in the output directory.
    2. Isotypes of unobserved ancestral genotypes are resolved so that
       isotype switching is minimized and switching order is respected.
       If the observed isotypes of an observed internal node and its
       children violate switching order, the observed internal node is
       replaced with an unobserved node with the same sequence and the
       observed node is re-attached as a child leaf. This procedure
       always allows switching order conflicts to be resolved, and should
       usually increase isotype transitions required in the resulting
       tree.
    3. Nodes are collapsed by sequence and isotype.

    Afterwards every node name gets its isotype appended (separated by a
    space) and every non-root branch length is reset to the Hamming
    distance between the node's sequence and its parent's sequence.

    Args:
        tree: ete3 Tree; it is copied, not modified.
        newidmap: mapping of sequence IDs to isotypes, such as that output by :meth:`utils.explode_idmap`.
        isotype_names: list or other sequence of isotype names observed, in correct switching order.
        weight_matrix: optional matrix forwarded unchanged to the
            observed-isotype pass (``_add_observed_isotypes``).

    Returns:
        A new ete3 Tree whose nodes have isotype annotations in the attribute ``isotype``.
        Node names in this tree also contain isotype names.
    """
    annotated = tree.copy()

    # The three passes below mutate ``annotated`` in place.
    _add_observed_isotypes(
        annotated, newidmap, isotype_names, weight_matrix=weight_matrix
    )
    _disambiguate_isotype(annotated)
    _collapse_tree_by_sequence_and_isotype(annotated)

    for current in annotated.traverse():
        current.name = " ".join((str(current.name), str(current.isotype)))
        # The root has no parent, hence no branch length to recompute.
        if current is not annotated:
            current.dist = hamming_distance(current.up.sequence,
                                            current.sequence)
    return annotated
Exemplo n.º 2
0
    def __init__(self,
                 tree: ete3.TreeNode = None,
                 allow_repeats: bool = False):
        """Build a collapsed tree from ``tree``, or an empty placeholder.

        When ``tree`` is given it is copied and the copy is collapsed:
        unobserved unifurcations are spliced out, zero-length edges are
        merged into their parent (summing abundances and pooling names),
        and children are sorted by subtree abundance then sequence. The
        (abundance, child count) pair of every node is cached in
        ``_cm_list`` together with the maxima ``_c_max`` and ``_m_max``.

        Args:
            tree: ete3 tree whose nodes carry ``abundance`` and
                ``sequence`` features; ``None`` leaves ``self.tree`` unset.
            allow_repeats: if true, repeated observed sequences in the
                collapsed tree are reported with ``print`` instead of
                raising.

        Raises:
            RuntimeError: if the set of observed genotype names changes
                during collapse, or if observed sequences repeat and
                ``allow_repeats`` is false.
        """
        if tree is None:
            self.tree = tree
            return

        def names_of(label):
            # A node name is either one string or a collection of strings.
            return {label} if isinstance(label, str) else set(label)

        def ladder_key(child):
            return (child.partition, child.sequence)

        self.tree = tree.copy()

        # Splice out unobserved internal nodes that have a single child and
        # re-measure the child's branch against its new parent.
        for nd in self.tree.iter_descendants():
            new_parent = nd.up
            if nd.abundance == 0 and len(nd.children) == 1:
                nd.delete(prevent_nondicotomic=False)
                only_child = nd.children[0]
                only_child.dist = utils.hamming_distance(
                    only_child.sequence, new_parent.sequence)

        # Collapse zero-length edges below the root. Names are pooled into
        # tuples so that bootstrap samples producing repeated genotypes keep
        # all of their taxon names.
        observed = {leaf.name for leaf in self.tree}
        observed.add(self.tree.name)
        for nd in self.tree.get_descendants(strategy="postorder"):
            if nd.dist != 0:
                continue
            above = nd.up
            above.abundance += nd.abundance
            child_names = names_of(nd.name)
            parent_names = names_of(above.name)
            child_is_observed = child_names < observed
            pooled = None
            if parent_names < observed:
                if child_is_observed:
                    pooled = tuple(child_names | parent_names)
            elif child_is_observed:
                pooled = tuple(child_names)
            if pooled is not None:
                # a lone name is stored bare rather than as a 1-tuple
                above.name = pooled[0] if len(pooled) == 1 else pooled
            nd.delete(prevent_nondicotomic=False)

        # Sanity check: collapsing must neither lose nor invent genotypes.
        surviving = set()
        for nd in self.tree.traverse():
            if nd.abundance > 0 or nd == self.tree:
                surviving.update(names_of(nd.name))
        if surviving != observed:
            raise RuntimeError(
                "observed genotypes don't match after "
                f"collapse\n\tbefore: {observed}"
                f"\n\tafter: {surviving}\n\t"
                "symmetric diff: "
                f"{observed ^ surviving}")
        assert sum(nd.abundance for nd in tree.traverse()) == sum(
            nd.abundance for nd in self.tree.traverse())

        # Number of observed nodes in excess of distinct observed sequences.
        observed_seqs = [
            nd.sequence for nd in self.tree.traverse() if nd.abundance > 0
        ]
        n_repeated = len(observed_seqs) - len(set(observed_seqs))
        if n_repeated:
            if not allow_repeats:
                raise RuntimeError(
                    "Repeated observed sequences in collapsed "
                    f"tree. {n_repeated} sequences were found repeated.")
            print("Repeated observed sequences in collapsed tree. "
                  f"{n_repeated} sequences were found repeated.")

        # Custom ladderize: order children by total abundance beneath them,
        # breaking ties by sequence. Postorder guarantees that every child's
        # partition exists before its parent's is summed.
        for nd in self.tree.traverse(strategy="postorder"):
            below = sum(child.partition for child in nd.children)
            nd.add_feature("partition", nd.abundance + below)
            nd.children.sort(key=ladder_key)

        # (c, m) = (abundance, number of children) per node, plus maxima
        self._cm_list = [(nd.abundance, len(nd.children))
                         for nd in self.tree.traverse()]
        self._c_max = max(c for c, _ in self._cm_list)
        self._m_max = max(m for _, m in self._cm_list)
Exemplo n.º 3
0
def treeprint(tree: ete3.TreeNode):
    """Serialize ``tree`` with each node labelled by its sequence.

    A copy of the tree is relabelled so the caller's node names are left
    untouched; the copy is then written with ete3's ``format=8``.

    Args:
        tree: ete3 tree whose nodes all carry a ``sequence`` attribute.

    Returns:
        Whatever ``TreeNode.write(format=8)`` returns for the relabelled copy.
    """
    relabelled = tree.copy()
    for current in relabelled.traverse():
        current.name = current.sequence
    serialized = relabelled.write(format=8)
    return serialized
Exemplo n.º 4
0
class CollapsedTree(LeavesAndClades):
    '''
    Here's a derived class for a collapsed tree, where we recurse into the mutant clades
          (4)
         / | \\
       (3)(1)(2)
           |   \\
          (2)  (1)

    The tree is stored in ``self.tree`` as an ete tree whose nodes carry a
    ``frequency`` feature; most methods other than ``l``, ``mle`` and
    ``simulate`` also read a ``sequence`` feature from every node.
    '''
    def __init__(self,
                 params=None,
                 tree=None,
                 frame=None,
                 collapse_syn=False,
                 allow_repeats=False):
        '''
        For initialization, either params or tree (or both) must be provided
        params: offspring distribution parameters
        tree: ete tree with frequency node feature. If uncollapsed, it will be collapsed
        frame: translation frame, with default None, no translation attempted
        collapse_syn: if True, branch lengths are first replaced by amino acid
                      hamming distances so synonymous edges collapse too
                      (requires tree and frame). NOTE: this rewrites the branch
                      lengths of the tree passed in, not just of the copy kept here
        allow_repeats: if True, repeated observed sequences in the collapsed tree
                       are printed as a message instead of raising RuntimeError
        raises RuntimeError for an invalid frame, for a genotype mismatch after
        collapse, or for repeated observed sequences when not allowed
        raises ValueError if collapse_syn is True without both tree and frame
        '''
        LeavesAndClades.__init__(self, params=params)
        if frame is not None and frame not in (1, 2, 3):
            raise RuntimeError('frame must be 1, 2, 3, or None')
        self.frame = frame

        if collapse_syn is True:
            # fail with a clear message rather than an AttributeError/TypeError
            if tree is None or frame is None:
                raise ValueError(
                    'collapse_syn requires both tree and frame to be provided')
            tree.dist = 0  # no branch above root
            for node in tree.iter_descendants():
                # both slices end where the child's last whole codon ends
                aa = self._translate_in_frame(node.sequence, frame,
                                              len(node.sequence))
                aa_parent = self._translate_in_frame(node.up.sequence, frame,
                                                     len(node.sequence))
                node.dist = hamming_distance(aa, aa_parent)

        if tree is not None:
            self.tree = tree.copy()
            # remove unobserved internal unifurcations
            for node in self.tree.iter_descendants():
                parent = node.up
                if node.frequency == 0 and len(node.children) == 1:
                    node.delete(prevent_nondicotomic=False)
                    node.children[0].dist = hamming_distance(
                        node.children[0].sequence, parent.sequence)

            # iterate over the tree below root and collapse edges of zero length
            # if the node is a leaf and its parent has nonzero frequency we combine taxa names to a set
            # this accommodates bootstrap samples that result in repeated genotypes
            observed_genotypes = set((leaf.name for leaf in self.tree))
            observed_genotypes.add(self.tree.name)
            for node in self.tree.get_descendants(strategy='postorder'):
                if node.dist == 0:
                    node.up.frequency += node.frequency
                    node_set = self._taxon_names(node.name)
                    node_up_set = self._taxon_names(node.up.name)
                    if node_up_set < observed_genotypes:
                        if node_set < observed_genotypes:
                            node.up.name = tuple(node_set | node_up_set)
                            if len(node.up.name) == 1:
                                node.up.name = node.up.name[0]
                    elif node_set < observed_genotypes:
                        node.up.name = tuple(node_set)
                        if len(node.up.name) == 1:
                            node.up.name = node.up.name[0]
                    node.delete(prevent_nondicotomic=False)

            # collapsing must neither lose nor invent observed genotypes
            final_observed_genotypes = set()
            for node in self.tree.traverse():
                if node.frequency > 0 or node == self.tree:
                    final_observed_genotypes |= self._taxon_names(node.name)
            if final_observed_genotypes != observed_genotypes:
                raise RuntimeError(
                    'observed genotypes don\'t match after collapse\n\tbefore: {}\n\tafter: {}\n\tsymmetric diff: {}'
                    .format(observed_genotypes, final_observed_genotypes,
                            observed_genotypes ^ final_observed_genotypes))
            assert sum(node.frequency for node in tree.traverse()) == sum(
                node.frequency for node in self.tree.traverse())

            # observed nodes in excess of distinct observed sequences
            rep_seq = sum(
                node.frequency > 0 for node in self.tree.traverse()) - len(
                    set([
                        node.sequence
                        for node in self.tree.traverse() if node.frequency > 0
                    ]))
            if rep_seq:
                if not allow_repeats:
                    raise RuntimeError(
                        'Repeated observed sequences in collapsed tree. {} sequences were found repeated.'
                        .format(rep_seq))
                print(
                    'Repeated observed sequences in collapsed tree. {} sequences were found repeated.'
                    .format(rep_seq))
            # a custom ladderize accounting for abundance and sequence to break ties in abundance
            for node in self.tree.traverse(strategy='postorder'):
                # add a partition feature and compute it recursively up the tree
                node.add_feature(
                    'partition',
                    node.frequency + sum(node2.partition
                                         for node2 in node.children))
                # sort children of this node based on partition and sequence
                node.children.sort(
                    key=lambda node: (node.partition, node.sequence))
        else:
            self.tree = tree

    @staticmethod
    def _taxon_names(name):
        '''return the taxon names held in a node name (a str or an iterable of str) as a set'''
        return set([name]) if isinstance(name, str) else set(name)

    @staticmethod
    def _translate_in_frame(sequence, frame, length):
        '''
        translate sequence from the 1-based reading frame, keeping only whole codons
        length decides where the last whole codon ends; callers pass the child's
        sequence length for both child and parent so the two slices line up
        '''
        start = frame - 1
        stop = start + 3 * ((length - start) // 3)
        return Seq(sequence[start:stop], generic_dna).translate()

    def l(self, params, sign=1):
        '''
        log likelihood of params, conditioned on collapsed tree, and its gradient wrt params
        optional parameter sign must be 1 or -1, with the latter useful for MLE by minimization
        returns the pair (sign * log likelihood, sign * gradient)
        raises ValueError if there is no tree or sign is invalid
        '''
        if self.tree is None:
            raise ValueError('tree data must be defined to compute likelihood')
        if sign not in (-1, 1):
            raise ValueError('sign must be 1 or -1')
        leaves_and_clades_list = [
            LeavesAndClades(c=node.frequency, m=len(node.children))
            for node in self.tree.traverse()
        ]
        # traverse() yields the root first, so index 0 is the root
        if leaves_and_clades_list[0].c == 0 and leaves_and_clades_list[
                0].m == 1 and leaves_and_clades_list[0].f(params)[0] == 0:
            # if unifurcation not possible under current model, add a pseudocount for the naive
            leaves_and_clades_list[0].c = 1
        # extract vector of function values and gradient components
        f_data = [
            leaves_and_clades.f(params)
            for leaves_and_clades in leaves_and_clades_list
        ]
        # fs is a column so it broadcasts against the rows of grad_fs
        fs = scipy.array([[x[0]] for x in f_data])
        logf = scipy.log(fs).sum()
        grad_fs = scipy.array([x[1] for x in f_data])
        grad_logf = (grad_fs / fs).sum(axis=0)
        return sign * logf, sign * grad_logf

    def mle(self, **kwargs):
        '''
        Maximum likelihood estimate for params given tree
        updates params if not None
        returns optimization result
        kwargs are forwarded to l(), except that sign is always forced to -1
        warns (RuntimeWarning) on a gradient check mismatch or a failed optimization
        '''
        # random initialization
        x_0 = (random.random(), random.random())
        bounds = ((.01, .99), (.001, .999))
        kwargs['sign'] = -1
        # compare the analytic gradient with finite differences at a fixed point
        grad_check = check_grad(lambda x: self.l(x, **kwargs)[0],
                                lambda x: self.l(x, **kwargs)[1], (.4, .5))
        if grad_check > 1e-3:
            warnings.warn(
                'gradient mismatches finite difference approximation by {}'.
                format(grad_check), RuntimeWarning)
        result = minimize(lambda x: self.l(x, **kwargs),
                          x0=x_0,
                          jac=True,
                          method='L-BFGS-B',
                          options={'ftol': 1e-10},
                          bounds=bounds)
        # update params if None and optimization successful
        if not result.success:
            warnings.warn('optimization not sucessful, ' + result.message,
                          RuntimeWarning)
        elif self.params is None:
            self.params = result.x.tolist()
        return result

    def simulate(self):
        '''
        simulate a collapsed tree given params
        replaces existing tree data member with simulation result, and returns self
        raises ValueError if params is None
        '''
        if self.params is None:
            raise ValueError('params must be defined for simulation')

        # initiate by running a LeavesAndClades simulation to get the number of clones and mutants
        # in the root node of the collapsed tree
        LeavesAndClades.simulate(self)
        self.tree = TreeNode()
        self.tree.add_feature('frequency', self.c)
        if self.m == 0:
            return self
        for _ in range(self.m):
            # ooooh, recursion
            child = CollapsedTree(params=self.params,
                                  frame=self.frame).simulate().tree
            child.dist = 1
            self.tree.add_child(child)

        return self

    def __str__(self):
        '''return a string representation for printing'''
        return 'params = ' + str(self.params) + '\ntree:\n' + str(self.tree)

    def render(self,
               outfile,
               idlabel=False,
               colormap=None,
               show_support=False,
               chain_split=None):
        '''
        render to image file, filetype inferred from suffix, svg for color images
        idlabel: also draw node names, and write the node sequences to a fasta
                 file next to outfile (same path with a .fasta extension)
        colormap: optional mapping from node name to either a color string or a
                  dict of color -> count (drawn as a pie chart)
        show_support: passed to the tree style to toggle branch support display
        chain_split: optional sequence index; edges are colored by whether the
                     part before and/or after that index mutated
        raises NotImplementedError if chain_split is combined with a frame
        '''
        def my_layout(node):
            circle_color = 'lightgray' if colormap is None or node.name not in colormap else colormap[
                node.name]
            text_color = 'black'
            if isinstance(circle_color, str):
                C = CircleFace(radius=max(3, 10 * scipy.sqrt(node.frequency)),
                               color=circle_color,
                               label={
                                   'text': str(node.frequency),
                                   'color': text_color
                               } if node.frequency > 0 else None)
                C.rotation = -90
                C.hz_align = 1
                faces.add_face_to_node(C, node, 0)
            else:
                P = PieChartFace(
                    [100 * x / node.frequency for x in circle_color.values()],
                    2 * 10 * scipy.sqrt(node.frequency),
                    2 * 10 * scipy.sqrt(node.frequency),
                    colors=[(color if color != 'None' else 'lightgray')
                            for color in list(circle_color.keys())],
                    line_color=None)
                T = TextFace(' '.join(
                    [str(x) for x in list(circle_color.values())]),
                             tight_text=True)
                T.hz_align = 1
                T.rotation = -90
                faces.add_face_to_node(P, node, 0, position='branch-right')
                faces.add_face_to_node(T, node, 1, position='branch-right')
            if idlabel:
                T = TextFace(node.name, tight_text=True, fsize=6)
                T.rotation = -90
                T.hz_align = 1
                faces.add_face_to_node(
                    T,
                    node,
                    1 if isinstance(circle_color, str) else 2,
                    position='branch-right')

        for node in self.tree.traverse():
            nstyle = NodeStyle()
            nstyle['size'] = 0
            if node.up is not None:
                # edge styling only when the sequence uses exactly the letters A, C, G, T
                if set(node.sequence.upper()) == set('ACGT'):
                    if chain_split is not None:
                        if self.frame is not None:
                            raise NotImplementedError(
                                'frame not implemented with chain_split')
                        leftseq_mutated = hamming_distance(
                            node.sequence[:chain_split],
                            node.up.sequence[:chain_split]) > 0
                        rightseq_mutated = hamming_distance(
                            node.sequence[chain_split:],
                            node.up.sequence[chain_split:]) > 0
                        if leftseq_mutated and rightseq_mutated:
                            nstyle['hz_line_color'] = 'purple'
                            nstyle['hz_line_width'] = 3
                        elif leftseq_mutated:
                            nstyle['hz_line_color'] = 'red'
                            nstyle['hz_line_width'] = 2
                        elif rightseq_mutated:
                            nstyle['hz_line_color'] = 'blue'
                            nstyle['hz_line_width'] = 2
                    if self.frame is not None:
                        aa = self._translate_in_frame(node.sequence,
                                                      self.frame,
                                                      len(node.sequence))
                        aa_parent = self._translate_in_frame(
                            node.up.sequence, self.frame, len(node.sequence))
                        nonsyn = hamming_distance(aa, aa_parent)
                        # highlight nodes whose translation contains a stop codon
                        if '*' in aa:
                            nstyle['bgcolor'] = 'red'
                        if nonsyn > 0:
                            nstyle['hz_line_color'] = 'black'
                            nstyle['hz_line_width'] = nonsyn
                        else:
                            nstyle['hz_line_type'] = 1
            node.set_style(nstyle)

        ts = TreeStyle()
        ts.show_leaf_name = False
        ts.rotation = 90
        ts.draw_aligned_faces_as_table = False
        ts.allow_face_overlap = True
        ts.layout_fn = my_layout
        ts.show_scale = False
        ts.show_branch_support = show_support
        self.tree.render(outfile, tree_style=ts)
        # if we labelled seqs, let's also write the alignment out so we have the sequences (including of internal nodes)
        if idlabel:
            aln = MultipleSeqAlignment([])
            for node in self.tree.traverse():
                aln.append(
                    SeqRecord(Seq(str(node.sequence), generic_dna),
                              id=str(node.name),
                              description='abundance={}'.format(
                                  node.frequency)))
            # context manager so the file is flushed and closed deterministically
            with open(os.path.splitext(outfile)[0] + '.fasta', 'w') as handle:
                AlignIO.write(aln, handle, 'fasta')

    def write(self, file_name):
        '''serialize tree to file (pickles this whole object)'''
        with open(file_name, 'wb') as f:
            pickle.dump(self, f)

    def compare(self, tree2, method='identity'):
        '''
        compare this tree to the other tree
        tree2: another CollapsedTree
        method: 'identity' returns True/False, 'MRCA' returns a normalized
                distance between common ancestor sequences, 'RF' returns the
                Robinson-Foulds distance
        raises ValueError for any other method
        '''
        if method == 'identity':
            # we compare lists of seq, parent, abundance
            # return true if these lists are identical, else false
            list1 = sorted((node.sequence, node.frequency,
                            node.up.sequence if node.up is not None else None)
                           for node in self.tree.traverse())
            list2 = sorted((node.sequence, node.frequency,
                            node.up.sequence if node.up is not None else None)
                           for node in tree2.tree.traverse())
            return list1 == list2
        elif method == 'MRCA':
            # matrix of hamming distance of common ancestors of taxa
            # takes a true and inferred tree as CollapsedTree objects
            taxa = [
                node.sequence for node in self.tree.traverse()
                if node.frequency
            ]
            n_taxa = len(taxa)
            d = scipy.zeros(shape=(n_taxa, n_taxa))
            sum_sites = scipy.zeros(shape=(n_taxa, n_taxa))
            # the built-in next() works on generators in both Python 2 and 3,
            # unlike the Python 2 only .next() method
            for i in range(n_taxa):
                nodei_true = next(
                    self.tree.iter_search_nodes(sequence=taxa[i]))
                nodei = next(tree2.tree.iter_search_nodes(sequence=taxa[i]))
                for j in range(i + 1, n_taxa):
                    nodej_true = next(
                        self.tree.iter_search_nodes(sequence=taxa[j]))
                    nodej = next(
                        tree2.tree.iter_search_nodes(sequence=taxa[j]))
                    MRCA_true = self.tree.get_common_ancestor(
                        (nodei_true, nodej_true)).sequence
                    MRCA = tree2.tree.get_common_ancestor(
                        (nodei, nodej)).sequence
                    d[i, j] = hamming_distance(MRCA_true, MRCA)
                    sum_sites[i, j] = len(MRCA_true)
            return d.sum() / sum_sites.sum()
        elif method == 'RF':
            tree1_copy = self.tree.copy(method='deepcopy')
            tree2_copy = tree2.tree.copy(method='deepcopy')
            # hang a leaf off every observed node so observed internal nodes
            # take part in the leaf-based comparison
            for treex in (tree1_copy, tree2_copy):
                for node in list(treex.traverse()):
                    if node.frequency > 0:
                        child = TreeNode()
                        child.add_feature('sequence', node.sequence)
                        node.add_child(child)
            try:
                return tree1_copy.robinson_foulds(tree2_copy,
                                                  attr_t1='sequence',
                                                  attr_t2='sequence',
                                                  unrooted_trees=True)[0]
            except Exception:
                # retry allowing duplicated sequences; unlike a bare except this
                # lets KeyboardInterrupt and SystemExit propagate
                return tree1_copy.robinson_foulds(tree2_copy,
                                                  attr_t1='sequence',
                                                  attr_t2='sequence',
                                                  unrooted_trees=True,
                                                  allow_dup=True)[0]
        else:
            raise ValueError('invalid distance method: ' + method)

    def get_split(self, node):
        '''
        return the bipartition resulting from clipping this node's edge above
        the result is a tuple of two sets of taxon names: those at or below
        node, and those in the rest of the tree (the root always counts as a taxon)
        the tree itself is left untouched
        raises ValueError if node is not in this tree or is the root
        '''
        if node.get_tree_root() != self.tree:
            raise ValueError('node not found')
        if node == self.tree:
            raise ValueError('this node is the root (no split above)')

        def taxa_of(nodes):
            # names of observed nodes (and of the root) among the given nodes
            taxa = set()
            for node2 in nodes:
                if node2.frequency > 0 or node2 == self.tree:
                    taxa |= self._taxon_names(node2.name)
            return taxa

        # work from node identities instead of detaching and re-attaching the
        # node, which would move it to the end of its parent's child list
        below = list(node.traverse())
        below_ids = set(id(node2) for node2 in below)
        taxa1 = taxa_of(below)
        taxa2 = taxa_of(node2 for node2 in self.tree.traverse()
                        if id(node2) not in below_ids)
        assert taxa1.isdisjoint(taxa2)
        assert taxa1.union(taxa2) == taxa_of(self.tree.traverse())
        # sets order by inclusion, so for disjoint sets this keeps (below, rest)
        return tuple(sorted([taxa1, taxa2]))

    @staticmethod
    def split_compatibility(split1, split2):
        '''
        return True if the two splits are compatible, i.e. some part of split1
        is disjoint from some part of split2
        raises ValueError if the splits do not cover the same taxa
        '''
        diff = split1[0].union(split1[1]) ^ split2[0].union(split2[1])
        if diff:
            raise ValueError(
                'splits do not cover the same taxa\n\ttaxa not in both: {}'.
                format(diff))
        for partition1 in split1:
            for partition2 in split2:
                if partition1.isdisjoint(partition2):
                    return True
        return False

    def support(self, bootstrap_trees_list, weights=None, compatibility=False):
        '''
        compute support from a list of bootstrap GCtrees
        weights (optional) is needed for weighting parsimony degenerate trees
        compatibility mode counts trees that don't disconfirm the split
        sets the support attribute of every non-root node and returns self
        '''
        # the bootstrap splits do not depend on the node being scored, so
        # compute them once instead of once per node of this tree
        boot_splits_list = [[
            tree.get_split(boot_node)
            for boot_node in tree.tree.get_descendants()
        ] for tree in bootstrap_trees_list]

        for node in self.tree.get_descendants():
            split = self.get_split(node)
            total = 0
            for i, boot_splits in enumerate(boot_splits_list):
                if compatibility:
                    # count the tree unless one of its splits conflicts
                    counted = all(
                        self.split_compatibility(split, boot_split)
                        for boot_split in boot_splits)
                else:
                    # count the tree only if it contains this exact split
                    counted = any(boot_split == split
                                  for boot_split in boot_splits)
                if counted:
                    total += weights[i] if weights is not None else 1
            node.support = total

        return self