Ejemplo n.º 1
0
        def _build(parse_node, tree_node=None):
            """Recursively convert a parse node into a ``TreeNode``.

            A ``LeafNode`` yields a node named after its literal that carries
            a one-element ``tokens`` feature. An ``InternalNode`` yields a
            node named after its symbol, with one child per parse child, a
            ``rule`` feature and a ``tokens`` feature that concatenates the
            children's tokens in order. Any other input leaves the node
            untouched (a list is additionally printed).

            ``tree_node`` is the node to populate; a fresh ``TreeNode`` is
            created when it is omitted. The populated node is returned.
            """
            if tree_node is None:
                tree_node = TreeNode()

            if isinstance(parse_node, list):
                print(parse_node)

            if isinstance(parse_node, LeafNode):
                literal = parse_node.literal
                leaf_token = parse_node.token
                tree_node.name = literal
                tree_node.add_feature("tokens", [leaf_token])
                return tree_node

            if not isinstance(parse_node, InternalNode):
                return tree_node

            label = parse_node.symbol
            production = parse_node.rule
            parse_children = parse_node.children

            # Tokens of the whole subtree, in left-to-right child order.
            collected = []
            for parse_child in parse_children:
                subtree = _build(parse_child)
                tree_node.add_child(subtree)
                collected.extend(subtree.tokens)

            tree_node.name = label
            tree_node.add_feature("rule", production)
            tree_node.add_feature("tokens", collected)
            return tree_node
Ejemplo n.º 2
0
def addDeadLineage(spTree):
    """
    Takes:
        - spTree (ete3.Tree) : species tree (left untouched; a deep copy is used)

    Returns:
        (ete3.Tree) : new root whose children are the copied species tree and
                      a dead lineage (name "-1") acting as outgroup.
                      All nodes have a "dead" feature (bool that is True only
                      for the dead lineage and the new root)
    """
    ingroup = deepcopy(spTree)
    ingroup.dist = 0.1
    for node in ingroup.traverse():
        node.add_feature("dead", False)

    root = TreeNode()
    root.add_feature("dead", True)
    root.dist = 0.0
    root.add_child(ingroup)

    # The dead branch is as long as the path from the new root to the first
    # leaf of the copied species tree.
    first_leaf = root.get_leaves()[0]
    height = root.get_distance(first_leaf)

    outgroup = TreeNode()
    outgroup.add_feature("dead", True)
    outgroup.name = "-1"
    outgroup.dist = height
    root.add_child(outgroup)

    return root
Ejemplo n.º 3
0
    def p_toArbre(self):
        """Build the display tree for this node, rooted at a "main()" node.

        The root gets three children, in order: a node named after
        ``str(self.sons[0])``, the tree returned by
        ``self.sons[1].c_toArbre()`` and the tree returned by
        ``self.sons[2].e_toArbre()``.
        """
        root = TreeNode()
        root.name = "main()"

        header = TreeNode()
        header.name = str(self.sons[0])

        branches = (
            header,
            self.sons[1].c_toArbre(),
            self.sons[2].e_toArbre(),
        )
        for branch in branches:
            root.add_child(branch)

        return root
Ejemplo n.º 4
0
    def e_toArbre(self):
        """Build the display tree for this node according to ``self.type``.

        "NUMBER" and "ID" produce a single labelled node; "OPBIN" produces a
        node named after ``self.value`` whose two children are the trees of
        ``self.sons[0]`` and ``self.sons[1]``. Any other type returns None.
        """
        kind = self.type
        if kind not in ("NUMBER", "ID", "OPBIN"):
            return None

        node = TreeNode()
        if kind == "NUMBER":
            node.name = "Number : " + str(self.value)
        elif kind == "ID":
            node.name = "Id : " + self.value
        else:
            node.name = self.value
            left = self.sons[0].e_toArbre()
            right = self.sons[1].e_toArbre()
            node.add_child(left)
            node.add_child(right)
        return node
Ejemplo n.º 5
0
    def c_toArbre(self):
        """Build the display tree for this node, named after ``self.value``.

        The node always gets two children, chosen by ``self.value``:
          - "=" : an "Id : <sons[0]>" node, then ``sons[1].e_toArbre()``
          - ";" : ``sons[0].c_toArbre()``, then ``sons[1].c_toArbre()``
          - anything else : ``sons[0].e_toArbre()``, then ``sons[1].c_toArbre()``
        """
        root = TreeNode()
        root.name = self.value

        if self.value == "=":
            first = TreeNode()
            first.name = "Id : " + self.sons[0]
            second = self.sons[1].e_toArbre()
        elif self.value == ';':
            first = self.sons[0].c_toArbre()
            second = self.sons[1].c_toArbre()
        else:
            first = self.sons[0].e_toArbre()
            second = self.sons[1].c_toArbre()

        root.add_child(first)
        root.add_child(second)
        return root
Ejemplo n.º 6
0
    def add_tree_to_distribution(self, tree):
        """
        Add the bipartitions of a tree to the CCP distribution

        Takes:
            - tree (ete3.Tree): phylogenetic tree. It is modified in place
              when its root has exactly three children (see below).

        Side effects:
            - registers the leaf names through self.get_leaf_id when no tree
              has been observed yet (self.nb_observation == 0)
            - calls self.add_tree_branch_to_distribution on every node, in
              postorder
            - increments self.nb_observation
            - prints a message and raises SystemExit(1) if a node with more
              than two children remains

        """

        if len(tree.children) == 3:
            ## special unrooted case where the tree begins with a trifurcation ...
            ## we artificially remove the trifurcation to avoid future problems:
            ## the 2nd and 3rd children are regrouped under a new internal node
            grouping = TreeNode()
            second = tree.children[1]
            third = tree.children[2]
            second.detach()
            third.detach()
            tree.add_child(grouping)
            grouping.add_child(second)
            grouping.add_child(third)

        for node in tree.traverse():
            if len(node.children) > 2:
                # single-argument print(...) behaves the same on Python 2 and 3
                print("multifurcation detected! Please provide bifurcating trees.")
                print("exiting now")
                # same exit status as exit(1), without relying on the
                # site-provided builtin
                raise SystemExit(1)

        if self.nb_observation == 0:  ##no tree has been observed yet: add all the leaves
            for leaf_name in tree.get_leaf_names():
                self.get_leaf_id(leaf_name)  ##adds the leaves to the CCP

        for node in tree.traverse("postorder"):  ##for each branch of the tree
            self.add_tree_branch_to_distribution(node)

        self.nb_observation += 1
Ejemplo n.º 7
0
def subdivideSpTree(spTree):
    """
    Takes:
        - spTree (ete3.Tree) : an ULTRAMETRIC species tree (left untouched;
                               the work is done on a deep copy)

    Returns:
        (ete3.Tree) : subdivided species tree where all nodes have a timeSlice feature
        or
        None if the species tree is not ultrametric (an error message is printed)

    Nodes are ranked by increasing distance from the root: the root gets the
    highest timeSlice (number of leaves - 1) and the value decreases by one
    per node until it reaches 0, where it stays. Nodes at equal distance keep
    the iteration order of the distance dictionary.
    Every branch spanning more than one time slice is then cut by inserting
    intermediate nodes, one per skipped slice. An inserted node inherits the
    "dead" feature of the node below it, when that node has one.
    """
    newSpTree = deepcopy(spTree)

    featureName = "timeSlice"

    ##1/ getting distance from root.
    Dheight = getDistFromRootDic(newSpTree , checkUltrametric = True)

    if Dheight is None:
        # single-argument print(...) behaves the same on Python 2 and 3
        print("!!ERROR!! : the species tree is not ultrametric")
        return None

    # we know that there is n-1 internal nodes (where n is the number of leaves)
    # hence the maximal timeSlice is n-1 (all leaves have timeSlice 0)

    ##2/assign timeSlice to nodes
    currentTS = len(newSpTree.get_leaves()) - 1

    # Sort on the height only: comparing the node objects themselves to break
    # ties is arbitrary on Python 2 and a TypeError on Python 3. sorted() is
    # stable, so ties keep the dictionary's iteration order.
    for n in sorted(Dheight, key=Dheight.get):
        n.add_feature(featureName, currentTS)

        if currentTS != 0:
            currentTS -= 1

    ##3/subdivide according to timeSlice
    # Snapshot the nodes first: the tree is edited while we walk this list.
    RealNodes = list(newSpTree.traverse())

    for n in RealNodes:
        if n.is_root():
            continue

        # number of time slices skipped by the branch above n
        nodeToAdd = n.up.timeSlice - n.timeSlice - 1

        for _ in range(nodeToAdd):
            parentNode = n.up

            n.detach()

            # intermediate node sitting one slice below the current parent
            NullNode = TreeNode()
            NullNode.add_feature(featureName, parentNode.timeSlice - 1)

            if "dead" in n.features:
                NullNode.add_feature("dead", n.dead)

            parentNode.add_child(NullNode)
            NullNode.add_child(n)

    return newSpTree
Ejemplo n.º 8
0
        # Fetch the tree node registered under the name `parent`, or create,
        # style and register a new one. `nodes` maps a name to its node.
        # NOTE(review): a newly created parentNode is not attached to any tree
        # in the visible lines -- presumably the root is registered in `nodes`
        # before this loop; verify.
        if parent in nodes:
            #parent = G.search_nodes(name = _parent)
            parentNode = nodes[parent]
        else:
            parentNode = TreeNode(name=parent)
            parentNode.set_style(nstyle)
            #parentNode.add_face(TextFace(_parent), column=0, position="aligned")
            #faces.add_face_to_node(TextFace(parent), parentNode, 0, position="aligned")
            nodes[parent] = parentNode

        # Same lookup for the child; a missing child is created directly as a
        # child of parentNode, then styled and registered.
        # NOTE(review): when `node` is already registered, it is only looked
        # up here and is not attached to parentNode -- confirm this is intended.
        #child = G.search_nodes(name = node)
        if node in nodes:
            childNode = nodes[node]
        else:
            childNode = parentNode.add_child(name=node)
            childNode.set_style(nstyle)
            #childNode.add_face(TextFace(node), column=0, position="aligned")
            #faces.add_face_to_node(TextFace(node), childNode, 0, position="aligned")
            nodes[node] = childNode

print(G)

# Shorten the branch above every node that has four or more ancestors;
# all other nodes get a unit-length branch.
for n in G.traverse():
    n.dist = 0.1 if len(n.get_ancestors()) >= 4 else 1.0

ts = TreeStyle()
Ejemplo n.º 9
0
class CollapsedTree(LeavesAndClades):
    '''
    Here's a derived class for a collapsed tree, where we recurse into the mutant clades
          (4)
         / | \\
       (3)(1)(2)
           |   \\
          (2)  (1)

    Nodes of the underlying ete tree carry a `frequency` feature (abundance of
    the genotype) and, for most operations, a `sequence` feature.
    '''
    def __init__(self,
                 params=None,
                 tree=None,
                 frame=None,
                 collapse_syn=False,
                 allow_repeats=False):
        '''
        For intialization, either params or tree (or both) must be provided
        params: offspring distribution parameters
        tree: ete tree with frequency node feature. If uncollapsed, it will be collapsed
        frame: tranlation frame, with default None, no tranlation attempted
        collapse_syn: if True, branch lengths of the PROVIDED tree are first
                      overwritten (in place) with amino-acid hamming distances,
                      so that synonymous branches get length zero and are
                      collapsed. Requires tree, and frame as soon as the tree
                      has descendants.
        allow_repeats: if True, repeated observed sequences in the collapsed
                       tree are reported with a printed message instead of
                       raising
        raises RuntimeError for an invalid frame, if the set of observed
        genotypes changes during collapse, or for repeated observed sequences
        (unless allow_repeats); raises ValueError if collapse_syn is requested
        without a tree
        '''
        LeavesAndClades.__init__(self, params=params)
        if frame is not None and frame not in (1, 2, 3):
            raise RuntimeError('frame must be 1, 2, 3, or None')
        self.frame = frame

        if collapse_syn is True:
            if tree is None:
                # fail with a clear message instead of an AttributeError on None
                raise ValueError('collapse_syn requires a tree')
            tree.dist = 0  # no branch above root
            for node in tree.iter_descendants():
                # both translations are truncated using the child's length
                seq_length = len(node.sequence)
                aa = self._translate_in_frame(node.sequence, frame,
                                              seq_length)
                aa_parent = self._translate_in_frame(node.up.sequence, frame,
                                                     seq_length)
                node.dist = hamming_distance(aa, aa_parent)

        if tree is not None:
            self.tree = tree.copy()
            # remove unobserved internal unifurcations
            for node in self.tree.iter_descendants():
                parent = node.up
                if node.frequency == 0 and len(node.children) == 1:
                    node.delete(prevent_nondicotomic=False)
                    # the only child now hangs from parent: recompute its branch length
                    node.children[0].dist = hamming_distance(
                        node.children[0].sequence, parent.sequence)

            # iterate over the tree below root and collapse edges of zero length
            # if the node is a leaf and it's parent has nonzero frequency we combine taxa names to a set
            # this acommodates bootstrap samples that result in repeated genotypes
            observed_genotypes = set((leaf.name for leaf in self.tree))
            observed_genotypes.add(self.tree.name)
            for node in self.tree.get_descendants(strategy='postorder'):
                if node.dist == 0:
                    node.up.frequency += node.frequency
                    node_set = set([node.name]) if isinstance(
                        node.name, str) else set(node.name)
                    node_up_set = set([node.up.name]) if isinstance(
                        node.up.name, str) else set(node.up.name)
                    if node_up_set < observed_genotypes:
                        if node_set < observed_genotypes:
                            node.up.name = tuple(node_set | node_up_set)
                            if len(node.up.name) == 1:
                                node.up.name = node.up.name[0]
                    elif node_set < observed_genotypes:
                        node.up.name = tuple(node_set)
                        if len(node.up.name) == 1:
                            node.up.name = node.up.name[0]
                    node.delete(prevent_nondicotomic=False)

            # sanity check: collapsing must neither lose nor invent genotypes
            final_observed_genotypes = self._genotype_names(
                self.tree.traverse())
            if final_observed_genotypes != observed_genotypes:
                raise RuntimeError(
                    'observed genotypes don\'t match after collapse\n\tbefore: {}\n\tafter: {}\n\tsymmetric diff: {}'
                    .format(observed_genotypes, final_observed_genotypes,
                            observed_genotypes ^ final_observed_genotypes))
            assert sum(node.frequency for node in tree.traverse()) == sum(
                node.frequency for node in self.tree.traverse())

            # number of observed nodes minus number of distinct observed sequences
            rep_seq = sum(
                node.frequency > 0 for node in self.tree.traverse()) - len(
                    set([
                        node.sequence
                        for node in self.tree.traverse() if node.frequency > 0
                    ]))
            if rep_seq:
                message = 'Repeated observed sequences in collapsed tree. {} sequences were found repeated.'.format(
                    rep_seq)
                if not allow_repeats:
                    raise RuntimeError(message)
                print(message)
            # a custom ladderize accounting for abundance and sequence to break ties in abundance
            for node in self.tree.traverse(strategy='postorder'):
                # add a partition feature and compute it recursively up the tree
                node.add_feature(
                    'partition',
                    node.frequency + sum(node2.partition
                                         for node2 in node.children))
                # sort children of this node based on partion and sequence
                node.children.sort(
                    key=lambda child: (child.partition, child.sequence))
        else:
            self.tree = tree

    @staticmethod
    def _translate_in_frame(sequence, frame, length):
        '''
        translate sequence starting at the (1-based) frame position
        the slice is cut to a whole number of codons, computed from length
        rather than from len(sequence), so that a parent sequence can be
        truncated exactly like its child
        '''
        start = frame - 1
        stop = start + 3 * ((length - start) // 3)
        return Seq(sequence[start:stop], generic_dna).translate()

    def _genotype_names(self, nodes):
        '''
        set of names carried by the given nodes that are observed
        (frequency > 0) or are the root of this tree
        a node name is either a str or an iterable of names
        '''
        names = set()
        for node in nodes:
            if node.frequency > 0 or node == self.tree:
                if isinstance(node.name, str):
                    names.add(node.name)
                else:
                    names.update(node.name)
        return names

    def l(self, params, sign=1):
        '''
        log likelihood of params, conditioned on collapsed tree, and its gradient wrt params
        optional parameter sign must be 1 or -1, with the latter useful for MLE by minimization
        returns the tuple (sign * log likelihood, sign * gradient)
        raises ValueError if there is no tree or sign is invalid
        '''
        if self.tree is None:
            raise ValueError('tree data must be defined to compute likelihood')
        if sign not in (-1, 1):
            raise ValueError('sign must be 1 or -1')
        # one (clone count, mutant clade count) observation per node, root first
        leaves_and_clades_list = [
            LeavesAndClades(c=node.frequency, m=len(node.children))
            for node in self.tree.traverse()
        ]
        if leaves_and_clades_list[0].c == 0 and leaves_and_clades_list[
                0].m == 1 and leaves_and_clades_list[0].f(params)[0] == 0:
            # if unifurcation not possible under current model, add a psuedocount for the naive
            leaves_and_clades_list[0].c = 1
        # extract vector of function values and gradient components
        f_data = [
            leaves_and_clades.f(params)
            for leaves_and_clades in leaves_and_clades_list
        ]
        fs = scipy.array([[x[0]] for x in f_data])
        logf = scipy.log(fs).sum()
        grad_fs = scipy.array([x[1] for x in f_data])
        grad_logf = (grad_fs / fs).sum(axis=0)
        return sign * logf, sign * grad_logf

    def mle(self, **kwargs):
        '''
        Maximum likelihood estimate for params given tree
        updates params if not None
        returns optimization result
        kwargs are forwarded to l(); sign is always forced to -1
        warns (RuntimeWarning) if the analytic gradient disagrees with a
        finite difference approximation, or if the optimization fails
        '''
        # random initalization
        x_0 = (random.random(), random.random())
        bounds = ((.01, .99), (.001, .999))
        kwargs['sign'] = -1
        # compare analytic and numerical gradients at a fixed test point
        grad_check = check_grad(lambda x: self.l(x, **kwargs)[0],
                                lambda x: self.l(x, **kwargs)[1], (.4, .5))
        if grad_check > 1e-3:
            warnings.warn(
                'gradient mismatches finite difference approximation by {}'.
                format(grad_check), RuntimeWarning)
        result = minimize(lambda x: self.l(x, **kwargs),
                          x0=x_0,
                          jac=True,
                          method='L-BFGS-B',
                          options={'ftol': 1e-10},
                          bounds=bounds)
        # update params if None and optimization successful
        if not result.success:
            warnings.warn('optimization not sucessful, ' + result.message,
                          RuntimeWarning)
        elif self.params is None:
            self.params = result.x.tolist()
        return result

    def simulate(self):
        '''
        simulate a collapsed tree given params
        replaces existing tree data member with simulation result, and returns self
        raises ValueError if params is None
        '''
        if self.params is None:
            raise ValueError('params must be defined for simulation')

        # initiate by running a LeavesAndClades simulation to get the number of clones and mutants
        # in the root node of the collapsed tree
        LeavesAndClades.simulate(self)
        self.tree = TreeNode()
        self.tree.add_feature('frequency', self.c)
        if self.m == 0:
            return self
        for _ in range(self.m):
            # ooooh, recursion
            child = CollapsedTree(params=self.params,
                                  frame=self.frame).simulate().tree
            child.dist = 1
            self.tree.add_child(child)

        return self

    def __str__(self):
        '''return a string representation for printing'''
        return 'params = ' + str(self.params) + '\ntree:\n' + str(self.tree)

    def render(self,
               outfile,
               idlabel=False,
               colormap=None,
               show_support=False,
               chain_split=None):
        '''
        render to image file, filetype inferred from suffix, svg for color images
        idlabel: also draw node names, and write the sequences of all nodes to
                 a fasta file named like outfile with a .fasta suffix
        colormap: optional mapping from node name to a color (str) or to a
                  dict of color -> count, drawn as a pie chart
        show_support: show branch support values
        chain_split: position splitting the sequence in two chains; branches
                     are colored by which chain mutated (not supported
                     together with frame)
        '''
        def my_layout(node):
            circle_color = 'lightgray' if colormap is None or node.name not in colormap else colormap[
                node.name]
            text_color = 'black'
            if isinstance(circle_color, str):
                C = CircleFace(radius=max(3, 10 * scipy.sqrt(node.frequency)),
                               color=circle_color,
                               label={
                                   'text': str(node.frequency),
                                   'color': text_color
                               } if node.frequency > 0 else None)
                C.rotation = -90
                C.hz_align = 1
                faces.add_face_to_node(C, node, 0)
            else:
                P = PieChartFace(
                    [100 * x / node.frequency for x in circle_color.values()],
                    2 * 10 * scipy.sqrt(node.frequency),
                    2 * 10 * scipy.sqrt(node.frequency),
                    colors=[(color if color != 'None' else 'lightgray')
                            for color in list(circle_color.keys())],
                    line_color=None)
                T = TextFace(' '.join(
                    [str(x) for x in list(circle_color.values())]),
                             tight_text=True)
                T.hz_align = 1
                T.rotation = -90
                faces.add_face_to_node(P, node, 0, position='branch-right')
                faces.add_face_to_node(T, node, 1, position='branch-right')
            if idlabel:
                T = TextFace(node.name, tight_text=True, fsize=6)
                T.rotation = -90
                T.hz_align = 1
                faces.add_face_to_node(
                    T,
                    node,
                    1 if isinstance(circle_color, str) else 2,
                    position='branch-right')

        for node in self.tree.traverse():
            nstyle = NodeStyle()
            nstyle['size'] = 0
            if node.up is not None:
                # branch styling is only applied when the sequence uses exactly the letters A, C, G and T
                if set(node.sequence.upper()) == set('ACGT'):
                    if chain_split is not None:
                        if self.frame is not None:
                            raise NotImplementedError(
                                'frame not implemented with chain_split')
                        leftseq_mutated = hamming_distance(
                            node.sequence[:chain_split],
                            node.up.sequence[:chain_split]) > 0
                        rightseq_mutated = hamming_distance(
                            node.sequence[chain_split:],
                            node.up.sequence[chain_split:]) > 0
                        if leftseq_mutated and rightseq_mutated:
                            nstyle['hz_line_color'] = 'purple'
                            nstyle['hz_line_width'] = 3
                        elif leftseq_mutated:
                            nstyle['hz_line_color'] = 'red'
                            nstyle['hz_line_width'] = 2
                        elif rightseq_mutated:
                            nstyle['hz_line_color'] = 'blue'
                            nstyle['hz_line_width'] = 2
                    if self.frame is not None:
                        seq_length = len(node.sequence)
                        aa = self._translate_in_frame(node.sequence,
                                                      self.frame, seq_length)
                        aa_parent = self._translate_in_frame(
                            node.up.sequence, self.frame, seq_length)
                        nonsyn = hamming_distance(aa, aa_parent)
                        if '*' in aa:
                            nstyle['bgcolor'] = 'red'
                        if nonsyn > 0:
                            nstyle['hz_line_color'] = 'black'
                            nstyle['hz_line_width'] = nonsyn
                        else:
                            nstyle['hz_line_type'] = 1
            node.set_style(nstyle)

        ts = TreeStyle()
        ts.show_leaf_name = False
        ts.rotation = 90
        ts.draw_aligned_faces_as_table = False
        ts.allow_face_overlap = True
        ts.layout_fn = my_layout
        ts.show_scale = False
        ts.show_branch_support = show_support
        self.tree.render(outfile, tree_style=ts)
        # if we labelled seqs, let's also write the alignment out so we have the sequences (including of internal nodes)
        if idlabel:
            aln = MultipleSeqAlignment([])
            for node in self.tree.traverse():
                aln.append(
                    SeqRecord(Seq(str(node.sequence), generic_dna),
                              id=str(node.name),
                              description='abundance={}'.format(
                                  node.frequency)))
            # context manager so the fasta file is flushed and closed
            with open(os.path.splitext(outfile)[0] + '.fasta',
                      'w') as fasta_file:
                AlignIO.write(aln, fasta_file, 'fasta')

    def write(self, file_name):
        '''serialize tree to file'''
        with open(file_name, 'wb') as f:
            pickle.dump(self, f)

    def compare(self, tree2, method='identity'):
        '''
        compare this tree to the other tree
        tree2: another CollapsedTree
        method:
            'identity': True if both trees hold the same (sequence, abundance,
                        parent sequence) triples, else False
            'MRCA': summed hamming distance between the MRCA sequences of all
                    pairs of observed taxa in both trees, divided by the summed
                    sequence lengths (0/0 with fewer than two observed taxa)
            'RF': Robinson-Foulds distance on the observed sequences
        raises ValueError for any other method
        '''
        if method == 'identity':
            # we compare lists of seq, parent, abundance
            # return true if these lists are identical, else false
            list1 = sorted((node.sequence, node.frequency,
                            node.up.sequence if node.up is not None else None)
                           for node in self.tree.traverse())
            list2 = sorted((node.sequence, node.frequency,
                            node.up.sequence if node.up is not None else None)
                           for node in tree2.tree.traverse())
            return list1 == list2
        elif method == 'MRCA':
            # matrix of hamming distance of common ancestors of taxa
            # takes a true and inferred tree as CollapsedTree objects
            taxa = [
                node.sequence for node in self.tree.traverse()
                if node.frequency
            ]
            n_taxa = len(taxa)
            d = scipy.zeros(shape=(n_taxa, n_taxa))
            sum_sites = scipy.zeros(shape=(n_taxa, n_taxa))
            # next(iterator) works on Python 2.6+ and 3; iterator.next() is Python 2 only
            for i in range(n_taxa):
                nodei_true = next(
                    self.tree.iter_search_nodes(sequence=taxa[i]))
                nodei = next(tree2.tree.iter_search_nodes(sequence=taxa[i]))
                for j in range(i + 1, n_taxa):
                    nodej_true = next(
                        self.tree.iter_search_nodes(sequence=taxa[j]))
                    nodej = next(
                        tree2.tree.iter_search_nodes(sequence=taxa[j]))
                    MRCA_true = self.tree.get_common_ancestor(
                        (nodei_true, nodej_true)).sequence
                    MRCA = tree2.tree.get_common_ancestor(
                        (nodei, nodej)).sequence
                    d[i, j] = hamming_distance(MRCA_true, MRCA)
                    sum_sites[i, j] = len(MRCA_true)
            return d.sum() / sum_sites.sum()
        elif method == 'RF':
            tree1_copy = self.tree.copy(method='deepcopy')
            tree2_copy = tree2.tree.copy(method='deepcopy')
            # hang a leaf carrying the sequence under every observed node, so
            # that observed internal nodes take part in the leaf-based comparison
            for treex in (tree1_copy, tree2_copy):
                for node in list(treex.traverse()):
                    if node.frequency > 0:
                        child = TreeNode()
                        child.add_feature('sequence', node.sequence)
                        node.add_child(child)
            try:
                return tree1_copy.robinson_foulds(tree2_copy,
                                                  attr_t1='sequence',
                                                  attr_t2='sequence',
                                                  unrooted_trees=True)[0]
            except Exception:
                # retry allowing duplicated attribute values; Exception rather
                # than a bare except, so KeyboardInterrupt/SystemExit propagate
                return tree1_copy.robinson_foulds(tree2_copy,
                                                  attr_t1='sequence',
                                                  attr_t2='sequence',
                                                  unrooted_trees=True,
                                                  allow_dup=True)[0]
        else:
            raise ValueError('invalid distance method: ' + method)

    def get_split(self, node):
        '''
        return the bipartition resulting from clipping this node's edge above
        the result is a tuple of two sets of genotype names
        the node is temporarily detached, then re-attached as the last child
        of its parent
        raises ValueError if node is not in this tree or is its root
        '''
        if node.get_tree_root() != self.tree:
            raise ValueError('node not found')
        if node == self.tree:
            raise ValueError('this node is the root (no split above)')
        parent = node.up
        taxa1 = self._genotype_names(node.traverse())
        node.detach()
        try:
            taxa2 = self._genotype_names(self.tree.traverse())
        finally:
            # always restore the tree, even if the traversal raised
            parent.add_child(node)
        assert taxa1.isdisjoint(taxa2)
        assert taxa1.union(taxa2) == self._genotype_names(
            self.tree.traverse())
        return tuple(sorted([taxa1, taxa2]))

    @staticmethod
    def split_compatibility(split1, split2):
        '''
        True if the two splits (pairs of sets) are compatible, i.e. some side
        of split1 is disjoint from some side of split2
        raises ValueError if the splits do not cover the same taxa
        '''
        diff = split1[0].union(split1[1]) ^ split2[0].union(split2[1])
        if diff:
            raise ValueError(
                'splits do not cover the same taxa\n\ttaxa not in both: {}'.
                format(diff))
        for partition1 in split1:
            for partition2 in split2:
                if partition1.isdisjoint(partition2):
                    return True
        return False

    def support(self, bootstrap_trees_list, weights=None, compatibility=False):
        '''
        compute support from a list of bootstrap GCtrees
        weights (optional) is needed for weighting parsimony degenerate trees
        compatibility mode counts trees that don't disconfirm the split
        the result is stored in the support attribute of every node below the
        root; returns self
        '''
        for node in self.tree.get_descendants():
            split = self.get_split(node)
            support = 0
            compatibility_ = 0
            for i, tree in enumerate(bootstrap_trees_list):
                # a bootstrap tree is compatible until one of its splits
                # conflicts, and supporting once one of its splits matches
                compatible = True
                supported = False
                for boot_node in tree.tree.get_descendants():
                    boot_split = tree.get_split(boot_node)
                    if compatibility and compatible and not self.split_compatibility(
                            split, boot_split):
                        compatible = False
                    if not compatibility and not supported and boot_split == split:
                        supported = True
                if supported:
                    support += weights[i] if weights is not None else 1
                if compatible:
                    compatibility_ += weights[i] if weights is not None else 1
            node.support = compatibility_ if compatibility else support

        return self
Ejemplo n.º 10
0
    def simulate(
        self,
        sequence: str,
        seq_bounds: Tuple[Tuple[int, int], Tuple[int, int]] = None,
        fitness_function: Callable = lambda seq: 0.9,
        lambda0: List[np.float64] = [1],
        frame: int = None,
        N_init: int = 1,
        N: int = None,
        T: int = None,
        n: int = None,
        verbose: bool = False,
    ) -> TreeNode:
        r"""Simulate a neutral binary branching process with the mutation model, returning a :class:`ete3.Treenode` object.

        Args:
            sequence: root nucleotide sequence
            seq_bounds: ranges for two subsequences used as two parallel genes
            fitness_function: mean number offspring as a function of sequence
            lambda0: baseline mutation rate(s)
            frame: coding frame of starting position(s)
            N_init: initial naive abundnace
            N: maximum population size
            T: maximum generation time
            n: sample size
            verbose: print more messages
        """
        # Checking the validity of the input parameters:
        if N is not None and T is not None:
            raise ValueError(
                "Only one of N and T can be used. One must be None.")
        elif N is None and T is None:
            raise ValueError("Either N or T must be specified.")
        if N is not None and n is not None and n > N:
            raise ValueError("n ({}) must not larger than N ({})".format(n, N))

        # Planting the tree:
        tree = TreeNode()
        tree.dist = 0
        tree.add_feature("sequence", sequence)
        tree.add_feature("terminated", False)
        tree.add_feature("abundance", 0)
        tree.add_feature("time", 0)
        # add fitness attribute, interpreted as mean of offspring distribution
        tree.add_feature("fitness", fitness_function(tree.sequence))

        if N_init > 1:
            for _ in range(N_init):
                child = TreeNode()
                child.dist = 0
                child.add_feature("sequence", sequence)
                child.add_feature("abundance", 0)
                child.add_feature("terminated", False)
                child.add_feature("time", 0)
                # add fitness attribute, interpreted as mean of offspring distribution
                child.add_feature("fitness", fitness_function(child.sequence))
                tree.add_child(child)

        t = 0  # <-- time
        leaves_unterminated = N_init
        while (leaves_unterminated > 0
               and (leaves_unterminated < N if N is not None else True)
               and (t < max(T) if T is not None else True)):
            if verbose:
                print("At time:", t)
            t += 1
            list_of_leaves = list(tree.iter_leaves())
            random.shuffle(list_of_leaves)
            for leaf in list_of_leaves:
                # add fitness attribute, interpreted as mean of offspring distribution
                leaf.add_feature("fitness", fitness_function(leaf.sequence))
                if not leaf.terminated:
                    n_children = poisson(leaf.fitness).rvs()
                    leaves_unterminated += (
                        n_children - 1
                    )  # <-- this kills the parent if we drew a zero
                    if not n_children:
                        leaf.terminated = True
                    for child_count in range(n_children):
                        # If a sequence pair is given, mutate each subsequence separately with its own mutation rate:
                        if seq_bounds is not None:
                            mutated_sequence1 = self.mutate(
                                leaf.
                                sequence[seq_bounds[0][0]:seq_bounds[0][1]],
                                lambda0=lambda0[0],
                                frame=frame,
                            )
                            mutated_sequence2 = self.mutate(
                                leaf.
                                sequence[seq_bounds[1][0]:seq_bounds[1][1]],
                                lambda0=lambda0[1],
                                frame=frame,
                            )
                            mutated_sequence = mutated_sequence1 + mutated_sequence2
                        else:
                            mutated_sequence = self.mutate(leaf.sequence,
                                                           lambda0=lambda0[0],
                                                           frame=frame)
                        child = TreeNode()
                        child.dist = utils.hamming_distance(
                            mutated_sequence, leaf.sequence)
                        child.add_feature("sequence", mutated_sequence)
                        child.add_feature("abundance", 0)
                        child.add_feature("terminated", False)
                        child.add_feature("time", t)
                        leaf.add_child(child)

        if N is not None and leaves_unterminated < N:
            raise RuntimeError(
                "tree terminated with {} leaves, {} desired".format(
                    leaves_unterminated, N))

        # each leaf in final generation gets an observed abundance of 1, unless downsampled
        if T is not None and len(T) > 1:
            # Iterate the intermediate time steps:
            for Ti in sorted(T)[:-1]:
                # Only sample those that have been 'sampled' at intermediate sampling times:
                final_leaves = [
                    leaf for leaf in tree.iter_descendants()
                    if leaf.time == Ti and leaf.sampled
                ]
                if len(final_leaves) < n:
                    raise RuntimeError(
                        "tree terminated with {} leaves, less than what desired after downsampling {}"
                        .format(leaves_unterminated, n))
                for (leaf) in (
                        final_leaves
                ):  # No need to down-sample, this was already done in the simulation loop
                    leaf.abundance = 1
        # Do the normal sampling of the last time step:
        final_leaves = [leaf for leaf in tree.iter_leaves() if leaf.time == t]
        # by default, downsample to the target simulation size
        if n is not None and len(final_leaves) >= n:
            for leaf in random.sample(final_leaves, n):
                leaf.abundance = 1
        elif n is None and N is not None:
            for leaf in random.sample(final_leaves, N):
                leaf.abundance = 1
        elif N is None and T is not None:
            for leaf in final_leaves:
                leaf.abundance = 1
        elif n is not None and len(final_leaves) < n:
            raise RuntimeError(
                "tree terminated with {} leaves, less than what desired after downsampling {}"
                .format(leaves_unterminated, n))
        else:
            raise RuntimeError("Unknown option.")

        # prune away lineages that are unobserved
        for node in tree.iter_descendants():
            if sum(node2.abundance for node2 in node.traverse()) == 0:
                node.detach()

        # # remove unobserved unifurcations
        # for node in tree.iter_descendants():
        #     parent = node.up
        #     if node.abundance == 0 and len(node.children) == 1:
        #         node.delete(prevent_nondicotomic=False)
        #         node.children[0].dist = hamming_distance(node.children[0].sequence, parent.sequence)

        # assign unique names to each node
        for i, node in enumerate(tree.traverse(), 1):
            node.name = "simcell_{}".format(i)

        # return the uncollapsed tree
        return tree