Example #1
def make_oldctree(tree):
    """Make an old CollapsedTree from an hDAG clade tree"""
    etetree = tree.to_ete(
        name_func=lambda n: n.attr["name"],
        features=["sequence"],
        feature_funcs={"abundance": lambda n: n.attr["abundance"]},
    )
    for node in etetree.traverse():
        if not node.is_leaf():
            node.abundance = 0
    etetree.dist = 0
    for node in etetree.iter_descendants():
        node.dist = utils.hamming_distance(node.up.sequence, node.sequence)
    return OldCollapsedTree(etetree)
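The branch lengths assigned above are plain Hamming distances between each node's sequence and its parent's. The actual utils.hamming_distance helper is not shown in these examples; a minimal sketch of such a function, assuming equal-length ungapped sequences, might look like this:

# Minimal sketch only; the real utils.hamming_distance may differ, e.g. in how
# it treats ambiguous bases.
def hamming_distance(seq1: str, seq2: str) -> int:
    """Count the positions at which two equal-length sequences differ."""
    if len(seq1) != len(seq2):
        raise ValueError("sequences must be the same length")
    return sum(a != b for a, b in zip(seq1, seq2))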
Example #2
def isotype_tree(
    tree: ete3.TreeNode,
    newidmap: Dict[str, Dict[str, str]],
    isotype_names: Sequence[str],
    weight_matrix: Optional[Sequence[Sequence[float]]] = None,
) -> ete3.TreeNode:
    """Method adds isotypes to ``tree``, minimizing isotype switching and
    obeying switching order.

    * Adds observed isotypes to each observed node in the collapsed
      trees output by gctree inference. If cells with the same sequence
      but different isotypes are observed, then collapsed tree nodes
      must be ‘exploded’ into new nodes with the appropriate isotypes
      and abundances. Each unique sequence ID generated by gctree is
      prepended to its observed isotype, and a new `isotyped.idmap`
      mapping these new sequence IDs to original sequence IDs is
      written in the output directory.
    * Resolves isotypes of unobserved ancestral genotypes in a way
      that minimizes isotype switching and obeys isotype switching
      order. If observed isotypes of an observed internal node and its
      children violate switching order, then the observed internal node
      is replaced with an unobserved node with the same sequence, and
      the observed internal node is placed as a child leaf. This
      procedure always allows switching order conflicts to be resolved,
      and should usually increase isotype transitions required in the
      resulting tree.

    Args:
        tree: ete3 Tree
        newidmap: mapping of sequence IDs to isotypes, such as that output by :meth:`utils.explode_idmap`.
        isotype_names: list or other sequence of isotype names observed, in correct switching order.
        weight_matrix: optional matrix of isotype transition weights.

    Returns:
        A new ete3 Tree whose nodes have isotype annotations in the attribute ``isotype``.
        Node names in this tree also contain isotype names.
    """
    tree = tree.copy()
    _add_observed_isotypes(tree,
                           newidmap,
                           isotype_names,
                           weight_matrix=weight_matrix)
    _disambiguate_isotype(tree)
    _collapse_tree_by_sequence_and_isotype(tree)
    for node in tree.traverse():
        node.name = str(node.name) + " " + str(node.isotype)
    for node in tree.iter_descendants():
        node.dist = hamming_distance(node.up.sequence, node.sequence)
    return tree
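For intuition, the "exploding" step described in the docstring can be pictured on a single node: a collapsed node observed with two different isotypes keeps its sequence but becomes an unobserved ancestor, and one observed child leaf is added per isotype. The sketch below is illustrative only, with made-up names and isotypes; the real work is done by _add_observed_isotypes above.

import ete3

# Hypothetical "explosion" of one observed node into per-isotype child leaves.
parent = ete3.TreeNode(name="seq1")
parent.add_feature("sequence", "AATGC")
parent.add_feature("abundance", 0)  # the original node becomes an unobserved ancestor
for new_id, isotype in [("seq1_IgM", "IgM"), ("seq1_IgG1", "IgG1")]:
    child = parent.add_child(name=new_id, dist=0)
    child.add_feature("sequence", parent.sequence)  # same sequence, zero distance
    child.add_feature("abundance", 1)
    child.add_feature("isotype", isotype)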
Example #3
def align_lineages(seq,
                   tree_t,
                   tree_i,
                   gap_penalty_pct=0,
                   known_root=True,
                   allow_double_gap=False):
    """Standard implementation of a Needleman-Wunsch algorithm as described
    here: http://telliott99.blogspot.com/2009/08/alignment-needleman-
    wunsch.html
    https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm And
    implemented here: https://github.com/alevchuk/pairwise-alignment-in-
    python/blob/master/alignment.py.

    gap_penalty_pct is the gap penalty relative to the sequence length
    of the sequences on the tree.
    """
    nt = find_node_by_seq(tree_t, seq)
    lt = reconstruct_lineage(tree_t, nt)
    ni = find_node_by_seq(tree_i, seq)
    li = reconstruct_lineage(tree_i, ni)
    # At least one of the lineages must be longer than just the root and the terminal node
    if len(lt) <= 2 and len(li) <= 2:
        return False

    # Gap penalty chosen not too large:
    gap_penalty = -1 * int((len(seq) / 100.0) * gap_penalty_pct)
    assert gap_penalty <= 0  # Penalties must be negative
    if gap_penalty == 0:
        # If the gap penalty is zero, only gaps in the shortest sequence will be allowed
        assert allow_double_gap is False

    # Generate a score matrix:
    kt = len(lt)
    ki = len(li)
    # Disallow gaps in the longer lineage:
    if allow_double_gap is False and kt > ki:
        # If the true lineage is longer than the inferred, allow gaps only in the inferred:
        gap_penalty_i = gap_penalty
        gap_penalty_j = -1 * float("inf")
    elif allow_double_gap is False and kt < ki:
        # If the inferred lineage is longer than the true, allow gaps only in the true:
        gap_penalty_i = -1 * float("inf")
        gap_penalty_j = gap_penalty
    elif allow_double_gap is False and kt == ki:
        # If the lineages are equally long, no gaps are allowed:
        gap_penalty_i = -1 * float("inf")
        gap_penalty_j = -1 * float("inf")
    else:
        gap_penalty_i = gap_penalty
        gap_penalty_j = gap_penalty

    sc_mat = np.zeros((kt, ki), dtype=np.float64)
    for i in range(kt):
        for j in range(ki):
            # Notice the score is defined by number of mismatches:
            # sc_mat[i, j] = len(lt[i]) - hamming_distance(lt[i], li[j])
            sc_mat[i, j] = -1 * hamming_distance(lt[i], li[j])

    ###    print(sc_mat)
    # Calculate the alignment scores:
    aln_sc = np.zeros((kt + 1, ki + 1), dtype=np.float64)
    for i in range(0, kt + 1):
        if known_root is True:
            aln_sc[i][0] = -1 * float("inf")
        else:
            aln_sc[i][0] = gap_penalty_i * i
    for j in range(0, ki + 1):
        if known_root is True:
            aln_sc[0][j] = -1 * float("inf")
        else:
            aln_sc[0][j] = gap_penalty_j * j
    aln_sc[0][0] = 0  # The top left is fixed to zero
    ###    print(aln_sc)
    for i in range(1, kt + 1):
        for j in range(1, ki + 1):
            match = aln_sc[i - 1][j - 1] + sc_mat[i - 1, j - 1]
            gap_in_inferred = aln_sc[i - 1][j] + gap_penalty_i
            gap_in_true = aln_sc[i][j - 1] + gap_penalty_j
            aln_sc[i][j] = max(match, gap_in_inferred, gap_in_true)
    ###    print(aln_sc)
    # Traceback to compute the alignment:
    align_t, align_i, asr_align = list(), list(), list()
    i, j = kt, ki
    alignment_score = aln_sc[i][j]
    while i > 0 and j > 0:
        sc_current = aln_sc[i][j]
        sc_diagonal = aln_sc[i - 1][j - 1]
        sc_up = aln_sc[i][j - 1]
        sc_left = aln_sc[i - 1][j]

        if sc_current == (sc_diagonal + sc_mat[i - 1, j - 1]):
            align_t.append(lt[i - 1])
            align_i.append(li[j - 1])
            i -= 1
            j -= 1
        elif sc_current == (sc_left + gap_penalty_i):
            align_t.append(lt[i - 1])
            align_i.append("-")
            i -= 1
        elif sc_current == (sc_up + gap_penalty_j):
            align_t.append("-")
            align_i.append(li[j - 1])
            j -= 1

    # If space left fill it with gaps:
    while i > 0:
        asr_align.append(gap_penalty_i)
        align_t.append(lt[i - 1])
        align_i.append("-")
        i -= 1
    while j > 0:
        asr_align.append(gap_penalty_j)
        align_t.append("-")
        align_i.append(li[j - 1])
        j -= 1

    max_penalty = 0
    for a, b in zip(align_t, align_i):
        if a == "-" or b == "-":
            max_penalty += gap_penalty
        else:
            max_penalty += -len(a)
    # Notice that the root and the terminal node are excluded from this comparison
    # by adding their lengths to the max_penalty:
    if known_root is True:
        max_penalty += 2 * len(lt[0])
    else:  # Or in the case of an unknown root, just add the terminal node
        max_penalty += len(lt[0])

    return [align_t, align_i, alignment_score, max_penalty]
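As a toy illustration of the score recurrence used above, here is the same dynamic program on two made-up three-base lineages, with a flat gap penalty and gaps allowed on both sides (unlike the stricter defaults of align_lineages):

import numpy as np

# Hypothetical lineages (root -> ... -> tip); scores are negative Hamming
# distances, exactly as in sc_mat above.
lt = ["AAA", "AAT", "ATT"]  # "true" lineage
li = ["AAA", "ATT"]         # "inferred" lineage
gap = -1


def hamming(a, b):
    return sum(x != y for x, y in zip(a, b))


aln = np.zeros((len(lt) + 1, len(li) + 1))
aln[1:, 0] = gap * np.arange(1, len(lt) + 1)
aln[0, 1:] = gap * np.arange(1, len(li) + 1)
for i in range(1, len(lt) + 1):
    for j in range(1, len(li) + 1):
        match = aln[i - 1, j - 1] - hamming(lt[i - 1], li[j - 1])
        aln[i, j] = max(match, aln[i - 1, j] + gap, aln[i, j - 1] + gap)
print(aln[-1, -1])  # best alignment score for the toy lineages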
Example #4
def main():
    parser = argparse.ArgumentParser(
        description="summary statistics of pre-tree data")
    parser.add_argument("input",
                        type=str,
                        nargs="+",
                        help="simulated fasta files")
    parser.add_argument("--experimental",
                        type=str,
                        nargs="*",
                        help="experimental fasta files")
    parser.add_argument("--outbase", type=str, help="output file base name")
    parser.add_argument("--root_idexp",
                        type=str,
                        default="root0",
                        help="root sequence ID")
    args = parser.parse_args()

    # simulations
    root_id = "root"
    for i, fname in enumerate(args.input):
        print(fname)
        seqs = {seq.id: str(seq.seq) for seq in fasta_parse(fname, "root")[0]}
        nseqs = len(seqs)
        if nseqs <= 2:
            continue

        distance_from_root, degree = zip(*[(
            hamming_distance(seqs[seqid], seqs[root_id]),
            min(
                hamming_distance(seqs[seqid], seqs[seqid2]) for seqid2 in seqs
                if seqid2 != root_id and seqid2 != seqid),
        ) for seqid in seqs if seqid != root_id])
        df = pd.DataFrame({
            "distance to root sequence": distance_from_root,
            "nearest neighbor distance": degree,
        })
        df["data set"] = i + 1
        if i == 0:
            aggdat = df
        else:
            aggdat = aggdat.append(df, ignore_index=True)

    ndatasets = len(set(aggdat["data set"]))

    # experimental
    if args.experimental is not None:
        for i, fname in enumerate(args.experimental):
            print(fname)
            seqs = {
                seq.id: str(seq.seq)
                for seq in fasta_parse(fname, args.root_idexp)[0]
            }
            nseqs = len(seqs)
            if nseqs <= 2:
                continue

            distance_from_root, degree = zip(*[(
                hamming_distance(seqs[seqid], seqs[args.root_idexp]),
                min(
                    hamming_distance(seqs[seqid], seqs[seqid2])
                    for seqid2 in seqs
                    if seqid2 != args.root_idexp and seqid2 != seqid),
            ) for seqid in seqs if seqid != args.root_idexp])
            df = pd.DataFrame({
                "distance to root sequence": distance_from_root,
                "nearest neighbor distance": degree,
            })
            df["data set"] = i + 1
            if i == 0:
                aggdat_exp = df
            else:
                aggdat_exp = aggdat_exp.append(df, ignore_index=True)

        ndatasets += len(set(aggdat_exp["data set"]))

    # bw = .3
    alpha = min([0.9, 20 / ndatasets])
    bins = range(
        max(
            aggdat["distance to root sequence"].max(),
            aggdat_exp["distance to root sequence"].max()
            if args.experimental is not None else 0,
        ) + 2)

    plt.figure(figsize=(6, 3))
    plt.subplot(1, 2, 1)
    ct = 0
    for dataset, dataset_aggdat in aggdat.groupby("data set"):
        ct += 1
        sns.distplot(
            dataset_aggdat["distance to root sequence"],
            bins=bins,
            kde=False,
            color="gray",
            hist_kws={
                "histtype": "step",
                "cumulative": True,
                "alpha": alpha,
                "lw": 1
            },
        )

    if args.experimental is not None:
        for dataset, dataset_aggdat in aggdat_exp.groupby("data set"):
            ct += 1
            sns.distplot(
                dataset_aggdat["distance to root sequence"],
                bins=bins,
                kde=False,
                color="black",
                hist_kws={
                    "histtype": "step",
                    "cumulative": True,
                    "alpha": 0.5,
                    "lw": 3,
                },
            )
    plt.xlabel("distance to root sequence")
    plt.xlim([0, bins[-1]])
    plt.ylabel("observed sequences")
    plt.tight_layout()

    bins = range(
        max(
            aggdat["nearest neighbor distance"].max(),
            aggdat_exp["nearest neighbor distance"].max()
            if args.experimental is not None else 0,
        ) + 2)

    plt.subplot(1, 2, 2)
    ct = 0
    for dataset, dataset_aggdat in aggdat.groupby("data set"):
        ct += 1
        sns.distplot(
            dataset_aggdat["nearest neighbor distance"],
            bins=bins,
            kde=False,
            color="gray",
            hist_kws={
                "histtype": "step",
                "cumulative": True,
                "alpha": alpha,
                "lw": 1
            },
        )
    if args.experimental is not None:
        for dataset, dataset_aggdat in aggdat_exp.groupby("data set"):
            ct += 1
            sns.distplot(
                dataset_aggdat["nearest neighbor distance"],
                bins=bins,
                kde=False,
                color="black",
                hist_kws={
                    "histtype": "step",
                    "cumulative": True,
                    "alpha": 0.5,
                    "lw": 3,
                },
            )
    plt.xlabel("nearest neighbor distance")
    plt.xlim([0, bins[-1]])
    plt.ylabel("")
    plt.tight_layout()

    plt.savefig(args.outbase + ".pdf")
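The zip(*[...]) comprehensions above compute, for each non-root sequence, its Hamming distance to the root and the distance to its nearest non-root neighbor. A small hypothetical example of the same per-sequence summary:

# Hypothetical three-sequence example of the per-sequence summary above.
seqs = {"root": "AAAA", "s1": "AAAT", "s2": "AATT"}
root_id = "root"


def hamming(a, b):
    return sum(x != y for x, y in zip(a, b))


distance_from_root = [hamming(seqs[s], seqs[root_id]) for s in seqs if s != root_id]
nearest_neighbor = [
    min(hamming(seqs[s], seqs[s2]) for s2 in seqs if s2 not in (root_id, s))
    for s in seqs
    if s != root_id
]
# distance_from_root == [1, 2]; nearest_neighbor == [1, 1]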
Example #5
        aggdat = df
    else:
        aggdat = aggdat.append(df, ignore_index=True)

sims = set(aggdat["simulation"])
nsims = len(sims)

if args.experimental is not None:
    new_aln, counts = fasta_parse(args.experimental,
                                  root="GL",
                                  id_abundances=True)[:2]
    exp_dict = {seq.id: str(seq.seq) for seq in new_aln}
    root_id = [seq for seq in exp_dict if "gl" in seq][0]
    abundance, distance_from_root, degree = zip(*[(
        counts[seq],
        hamming_distance(exp_dict[seq], exp_dict[root_id]),
        sum(
            hamming_distance(exp_dict[seq], exp_dict[seq2]) == 1
            for seq2 in exp_dict if seq2 != seq and counts[seq2] != 0),
    ) for seq in exp_dict if counts[seq] != 0])
    exp_stats = pd.DataFrame({
        "genotype abundance": abundance,
        "Hamming distance to root genotype": distance_from_root,
        "Hamming neighbor genotypes": degree,
    })

# bw = .3
alpha = min([0.9, 20 / nsims])
bins = range(
    max(
        aggdat["Hamming distance to root genotype"].max(),
Example #6
def simulate(args):
    """Simulation subprogram.

    Simulates a Galton–Watson branching process, with mutation probabilities
    according to a user-defined motif model, e.g. S5F
    """
    random.seed(a=args.seed)
    mutation_model = mm.MutationModel(args.mutability, args.substitution)
    if args.lambda0 is None:
        args.lambda0 = [max([1, int(0.01 * len(args.sequence))])]
    args.sequence = args.sequence.upper()
    if args.sequence2 is not None:
        # Use the same mutation rate on both sequences
        if len(args.lambda0) == 1:
            args.lambda0 = [args.lambda0[0], args.lambda0[0]]
        elif len(args.lambda0) != 2:
            raise Exception("Only one or two lambda0 values can be defined "
                            "for a two-sequence simulation.")
        # Require both sequences to be in frame 1:
        if args.frame is not None and args.frame != 1:
            if args.verbose:
                print("Warning: When simulating with two sequences they are "
                      "truncated to be beginning at frame 1.")
            # Truncate both sequences to whole codons starting at the given frame:
            codon_start = args.frame - 1
            n_codons1 = (len(args.sequence) - codon_start) // 3
            args.sequence = args.sequence[codon_start:codon_start + 3 * n_codons1]
            n_codons2 = (len(args.sequence2) - codon_start) // 3
            args.sequence2 = args.sequence2[codon_start:codon_start + 3 * n_codons2]
        # Extract the bounds between sequence 1 and 2:
        seq_bounds = (
            (0, len(args.sequence)),
            (len(args.sequence), len(args.sequence) + len(args.sequence2)),
        )
        # Merge the two sequences to simplify future dealing with the pair:
        args.sequence += args.sequence2
    else:
        seq_bounds = None

    trials = 1000
    # this loop makes us resimulate if size too small, or backmutation
    for trial in range(trials):
        try:
            tree = mutation_model.simulate(
                args.sequence,
                seq_bounds=seq_bounds,
                progeny=lambda seq: args.lambda_,
                lambda0=args.lambda0,
                n=args.n,
                N=args.N,
                T=args.T,
                frame=args.frame,
                verbose=args.verbose,
            )
            # this will fail if backmutations
            collapsed_tree = bp.CollapsedTree(tree=tree)
            tree.ladderize()
            uniques = sum(node.abundance > 0
                          for node in collapsed_tree.tree.traverse())
            if uniques < 2:
                raise RuntimeError(f"collapsed tree contains {uniques} "
                                   "sampled sequences")
            break
        except RuntimeError as e:
            print(f"{e}, trying again")
        else:
            raise
    if trial == trials - 1:
        raise RuntimeError(f"{trials} attempts exceeded")

    # In the case of a sequence pair print them to separate files:
    if args.sequence2 is not None:
        fh1 = open(args.outbase + ".simulation_seq1.fasta", "w")
        fh2 = open(args.outbase + ".simulation_seq2.fasta", "w")
        fh1.write(">root\n")
        fh1.write(args.sequence[seq_bounds[0][0]:seq_bounds[0][1]] + "\n")
        fh2.write(">root\n")
        fh2.write(args.sequence[seq_bounds[1][0]:seq_bounds[1][1]] + "\n")
        for leaf in tree.iter_leaves():
            if leaf.abundance != 0:
                fh1.write(">" + leaf.name + "\n")
                fh1.write(leaf.sequence[seq_bounds[0][0]:seq_bounds[0][1]] +
                          "\n")
                fh2.write(">" + leaf.name + "\n")
                fh2.write(leaf.sequence[seq_bounds[1][0]:seq_bounds[1][1]] +
                          "\n")
    else:
        with open(args.outbase + ".simulation.fasta", "w") as f:
            f.write(">root\n")
            f.write(args.sequence + "\n")
            for leaf in tree.iter_leaves():
                if leaf.abundance != 0:
                    f.write(">" + leaf.name + "\n")
                    f.write(leaf.sequence + "\n")

    # some observable simulation stats to write
    abundance, distance_from_root, degree = zip(*[(
        node.abundance,
        utils.hamming_distance(node.sequence, args.sequence),
        sum(
            utils.hamming_distance(node.sequence, node2.sequence) == 1
            for node2 in collapsed_tree.tree.traverse()
            if node2.abundance and node2 is not node),
    ) for node in collapsed_tree.tree.traverse() if node.abundance])
    stats = pd.DataFrame({
        "genotype abundance": abundance,
        "Hamming distance to root genotype": distance_from_root,
        "Hamming neighbor genotypes": degree,
    })
    stats.to_csv(args.outbase + ".simulation.stats.tsv", sep="\t", index=False)

    print(f"{sum(leaf.abundance for leaf in collapsed_tree.tree.traverse())}"
          " simulated observed sequences")

    # render the full lineage tree
    ts = ete3.TreeStyle()
    ts.rotation = 90
    ts.show_leaf_name = False
    ts.show_scale = False

    colors = {}
    palette = ete3.SVG_COLORS
    palette -= set(["black", "white", "gray"])
    palette = itertools.cycle(list(palette))  # <-- circular iterator

    colors[tree.sequence] = "gray"

    for n in tree.traverse():
        nstyle = ete3.NodeStyle()
        nstyle["size"] = 10
        if args.plotAA:
            if n.AAseq not in colors:
                colors[n.AAseq] = next(palette)
            nstyle["fgcolor"] = colors[n.AAseq]
        else:
            if n.sequence not in colors:
                colors[n.sequence] = next(palette)
            nstyle["fgcolor"] = colors[n.sequence]
        n.set_style(nstyle)

    # this makes the rendered branch lengths correspond to time
    for node in tree.iter_descendants():
        node.dist = node.time - node.up.time
    tree.render(args.outbase + ".simulation.lineage_tree.svg", tree_style=ts)

    # render collapsed tree
    # create an id-wise colormap
    # NOTE: node.name can be a set
    colormap = {
        node.name: colors[node.sequence]
        for node in collapsed_tree.tree.traverse()
    }
    collapsed_tree.write(args.outbase + ".simulation.collapsed_tree.p")
    collapsed_tree.render(
        args.outbase + ".simulation.collapsed_tree.svg",
        idlabel=args.idlabel,
        colormap=colormap,
        frame=args.frame,
    )
    # print colormap to file
    with open(args.outbase + ".simulation.collapsed_tree.colormap.tsv",
              "w") as f:
        for name, color in colormap.items():
            f.write((name if isinstance(name, str) else ",".join(name)) +
                    "\t" + color + "\n")
Example #7
    def __init__(self,
                 tree: ete3.TreeNode = None,
                 allow_repeats: bool = False):
        if tree is not None:
            self.tree = tree.copy()
            # remove unobserved internal unifurcations
            for node in self.tree.iter_descendants():
                parent = node.up
                if node.abundance == 0 and len(node.children) == 1:
                    node.delete(prevent_nondicotomic=False)
                    node.children[0].dist = utils.hamming_distance(
                        node.children[0].sequence, parent.sequence)

            # Iterate over the tree below the root and collapse edges of zero
            # length. If the node is a leaf and its parent has nonzero
            # abundance, we combine taxa names into a set to accommodate
            # bootstrap samples that result in repeated genotypes.
            observed_genotypes = set((leaf.name for leaf in self.tree))
            observed_genotypes.add(self.tree.name)
            for node in self.tree.get_descendants(strategy="postorder"):
                if node.dist == 0:
                    node.up.abundance += node.abundance
                    if isinstance(node.name, str):
                        node_set = set([node.name])
                    else:
                        node_set = set(node.name)
                    if isinstance(node.up.name, str):
                        node_up_set = set([node.up.name])
                    else:
                        node_up_set = set(node.up.name)
                    if node_up_set < observed_genotypes:
                        if node_set < observed_genotypes:
                            node.up.name = tuple(node_set | node_up_set)
                            if len(node.up.name) == 1:
                                node.up.name = node.up.name[0]
                    elif node_set < observed_genotypes:
                        node.up.name = tuple(node_set)
                        if len(node.up.name) == 1:
                            node.up.name = node.up.name[0]
                    node.delete(prevent_nondicotomic=False)

            final_observed_genotypes = set()
            for node in self.tree.traverse():
                if node.abundance > 0 or node == self.tree:
                    for name in ((node.name, )
                                 if isinstance(node.name, str) else node.name):
                        final_observed_genotypes.add(name)
            if final_observed_genotypes != observed_genotypes:
                raise RuntimeError(
                    "observed genotypes don't match after "
                    f"collapse\n\tbefore: {observed_genotypes}"
                    f"\n\tafter: {final_observed_genotypes}\n\t"
                    "symmetric diff: "
                    f"{observed_genotypes ^ final_observed_genotypes}")
            assert sum(node.abundance for node in tree.traverse()) == sum(
                node.abundance for node in self.tree.traverse())

            rep_seq = sum(
                node.abundance > 0 for node in self.tree.traverse()) - len(
                    set([
                        node.sequence
                        for node in self.tree.traverse() if node.abundance > 0
                    ]))
            if not allow_repeats and rep_seq:
                raise RuntimeError(
                    "Repeated observed sequences in collapsed "
                    f"tree. {rep_seq} sequences were found repeated.")
            elif allow_repeats and rep_seq:
                rep_seq = sum(node.abundance > 0
                              for node in self.tree.traverse()) - len(
                                  set([
                                      node.sequence
                                      for node in self.tree.traverse()
                                      if node.abundance > 0
                                  ]))
                print("Repeated observed sequences in collapsed tree. "
                      f"{rep_seq} sequences were found repeated.")
            # a custom ladderize accounting for abundance and sequence to break
            # ties in abundance
            for node in self.tree.traverse(strategy="postorder"):
                # add a partition feature and compute it recursively up the tree
                node.add_feature(
                    "partition",
                    node.abundance + sum(node2.partition
                                         for node2 in node.children),
                )
                # sort children of this node based on partition and sequence
                node.children.sort(
                    key=lambda node: (node.partition, node.sequence))

            # create list of (c, m) for each node
            self._cm_list = [(node.abundance, len(node.children))
                             for node in self.tree.traverse()]
            # store max c and m
            self._c_max = max(node.abundance for node in self.tree.traverse())
            self._m_max = max(
                len(node.children) for node in self.tree.traverse())
        else:
            self.tree = tree
Example #8
    def simulate(
        self,
        sequence: str,
        seq_bounds: Tuple[Tuple[int, int], Tuple[int, int]] = None,
        fitness_function: Callable = lambda seq: 0.9,
        lambda0: List[np.float64] = [1],
        frame: int = None,
        N_init: int = 1,
        N: int = None,
        T: int = None,
        n: int = None,
        verbose: bool = False,
    ) -> TreeNode:
        r"""Simulate a neutral binary branching process with the mutation model, returning a :class:`ete3.Treenode` object.

        Args:
            sequence: root nucleotide sequence
            seq_bounds: ranges for two subsequences used as two parallel genes
            fitness_function: mean number offspring as a function of sequence
            lambda0: baseline mutation rate(s)
            frame: coding frame of starting position(s)
            N_init: initial naive abundance
            N: maximum population size
            T: maximum generation time
            n: sample size
            verbose: print more messages
        """
        # Checking the validity of the input parameters:
        if N is not None and T is not None:
            raise ValueError(
                "Only one of N and T can be used. One must be None.")
        elif N is None and T is None:
            raise ValueError("Either N or T must be specified.")
        if N is not None and n is not None and n > N:
            raise ValueError("n ({}) must not larger than N ({})".format(n, N))

        # Planting the tree:
        tree = TreeNode()
        tree.dist = 0
        tree.add_feature("sequence", sequence)
        tree.add_feature("terminated", False)
        tree.add_feature("abundance", 0)
        tree.add_feature("time", 0)
        # add fitness attribute, interpreted as mean of offspring distribution
        tree.add_feature("fitness", fitness_function(tree.sequence))

        if N_init > 1:
            for _ in range(N_init):
                child = TreeNode()
                child.dist = 0
                child.add_feature("sequence", sequence)
                child.add_feature("abundance", 0)
                child.add_feature("terminated", False)
                child.add_feature("time", 0)
                # add fitness attribute, interpreted as mean of offspring distribution
                child.add_feature("fitness", fitness_function(child.sequence))
                tree.add_child(child)

        t = 0  # <-- time
        leaves_unterminated = N_init
        while (leaves_unterminated > 0
               and (leaves_unterminated < N if N is not None else True)
               and (t < max(T) if T is not None else True)):
            if verbose:
                print("At time:", t)
            t += 1
            list_of_leaves = list(tree.iter_leaves())
            random.shuffle(list_of_leaves)
            for leaf in list_of_leaves:
                # add fitness attribute, interpreted as mean of offspring distribution
                leaf.add_feature("fitness", fitness_function(leaf.sequence))
                if not leaf.terminated:
                    n_children = poisson(leaf.fitness).rvs()
                    leaves_unterminated += (
                        n_children - 1
                    )  # <-- this kills the parent if we drew a zero
                    if not n_children:
                        leaf.terminated = True
                    for child_count in range(n_children):
                        # If sequence pair mutate them separately with their own mutation rate:
                        if seq_bounds is not None:
                            mutated_sequence1 = self.mutate(
                                leaf.sequence[seq_bounds[0][0]:seq_bounds[0][1]],
                                lambda0=lambda0[0],
                                frame=frame,
                            )
                            mutated_sequence2 = self.mutate(
                                leaf.sequence[seq_bounds[1][0]:seq_bounds[1][1]],
                                lambda0=lambda0[1],
                                frame=frame,
                            )
                            mutated_sequence = mutated_sequence1 + mutated_sequence2
                        else:
                            mutated_sequence = self.mutate(leaf.sequence,
                                                           lambda0=lambda0[0],
                                                           frame=frame)
                        child = TreeNode()
                        child.dist = utils.hamming_distance(
                            mutated_sequence, leaf.sequence)
                        child.add_feature("sequence", mutated_sequence)
                        child.add_feature("abundance", 0)
                        child.add_feature("terminated", False)
                        child.add_feature("time", t)
                        leaf.add_child(child)

        if N is not None and leaves_unterminated < N:
            raise RuntimeError(
                "tree terminated with {} leaves, {} desired".format(
                    leaves_unterminated, N))

        # each leaf in final generation gets an observed abundance of 1, unless downsampled
        if T is not None and len(T) > 1:
            # Iterate the intermediate time steps:
            for Ti in sorted(T)[:-1]:
                # Only sample those that have been 'sampled' at intermediate sampling times:
                final_leaves = [
                    leaf for leaf in tree.iter_descendants()
                    if leaf.time == Ti and leaf.sampled
                ]
                if len(final_leaves) < n:
                    raise RuntimeError(
                        "tree terminated with {} leaves, less than what desired after downsampling {}"
                        .format(leaves_unterminated, n))
                for leaf in final_leaves:
                    # No need to down-sample; this was already done in the simulation loop
                    leaf.abundance = 1
        # Do the normal sampling of the last time step:
        final_leaves = [leaf for leaf in tree.iter_leaves() if leaf.time == t]
        # by default, downsample to the target simulation size
        if n is not None and len(final_leaves) >= n:
            for leaf in random.sample(final_leaves, n):
                leaf.abundance = 1
        elif n is None and N is not None:
            for leaf in random.sample(final_leaves, N):
                leaf.abundance = 1
        elif N is None and T is not None:
            for leaf in final_leaves:
                leaf.abundance = 1
        elif n is not None and len(final_leaves) < n:
            raise RuntimeError(
                "tree terminated with {} leaves, less than what desired after downsampling {}"
                .format(leaves_unterminated, n))
        else:
            raise RuntimeError("Unknown option.")

        # prune away lineages that are unobserved
        for node in tree.iter_descendants():
            if sum(node2.abundance for node2 in node.traverse()) == 0:
                node.detach()

        # # remove unobserved unifurcations
        # for node in tree.iter_descendants():
        #     parent = node.up
        #     if node.abundance == 0 and len(node.children) == 1:
        #         node.delete(prevent_nondicotomic=False)
        #         node.children[0].dist = hamming_distance(node.children[0].sequence, parent.sequence)

        # assign unique names to each node
        for i, node in enumerate(tree.traverse(), 1):
            node.name = "simcell_{}".format(i)

        # return the uncollapsed tree
        return tree
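Each pass through the while loop above is one generation of the branching process: every unterminated leaf draws a Poisson number of offspring with its fitness as the mean, and a draw of zero terminates that lineage. A toy version of just that population dynamic, ignoring sequences and mutation:

import numpy as np
from scipy.stats import poisson

# Toy Galton–Watson population trajectory (no sequences, no tree structure).
np.random.seed(0)
fitness = 0.9  # mean offspring number, as in the default fitness_function
population = 5
for generation in range(1, 11):
    if population == 0:
        break
    offspring = poisson(fitness).rvs(size=population)
    population = int(offspring.sum())
    print(f"generation {generation}: {population} live leaves")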