Exemplo n.º 1
0
 def compare(self, tree2, method='identity'):
     '''compare this tree to the other tree

     method -- one of:
         'identity': True if both trees contain the same multiset of
             (sequence, frequency, parent sequence) triples, else False.
         'MRCA': for every pair of observed taxa (sequences of nodes of
             this tree with nonzero frequency), compare the sequence of their
             most recent common ancestor in this tree with the one in tree2.
             Returns the summed hamming distance divided by the summed MRCA
             sequence length (nan when there are fewer than two taxa).
         'RF': Robinson-Foulds distance between copies of both trees in
             which every observed node gets an extra leaf labelled with its
             sequence.

     Raises ValueError for an unknown method, or (MRCA) when a taxon
     sequence cannot be found in tree2.
     '''
     from collections import Counter

     if method == 'identity':
         # Compare as multisets rather than as sorted lists: sorting raises
         # TypeError on Python 3 when two triples tie on sequence and
         # frequency and only one of them has a parent (None vs str).
         def _triples(tree):
             return Counter(
                 (node.sequence, node.frequency,
                  node.up.sequence if node.up is not None else None)
                 for node in tree.traverse())

         return _triples(self.tree) == _triples(tree2.tree)
     elif method == 'MRCA':
         # takes a true (self) and inferred (tree2) tree as CollapsedTree objects
         def _first_node(tree, seq):
             # First node carrying this sequence. next() works on Python 2
             # and 3, unlike the generator's .next() method.
             node = next(tree.iter_search_nodes(sequence=seq), None)
             if node is None:
                 raise ValueError(
                     'sequence not found in tree: {}'.format(seq))
             return node

         taxa = [
             node.sequence for node in self.tree.traverse()
             if node.frequency
         ]
         # Look every taxon up once instead of once per pair.
         true_nodes = [_first_node(self.tree, seq) for seq in taxa]
         inferred_nodes = [_first_node(tree2.tree, seq) for seq in taxa]
         total_distance = 0
         total_sites = 0
         for i in range(len(taxa)):
             for j in range(i + 1, len(taxa)):
                 mrca_true = self.tree.get_common_ancestor(
                     (true_nodes[i], true_nodes[j])).sequence
                 mrca = tree2.tree.get_common_ancestor(
                     (inferred_nodes[i], inferred_nodes[j])).sequence
                 total_distance += hamming_distance(mrca_true, mrca)
                 total_sites += len(mrca_true)
         if not total_sites:
             # Fewer than two taxa: nothing to compare (0/0).
             return float('nan')
         return float(total_distance) / total_sites
     elif method == 'RF':
         # Work on deep copies so neither input tree is modified.
         tree1_copy = self.tree.copy(method='deepcopy')
         tree2_copy = tree2.tree.copy(method='deepcopy')
         # Hang a leaf labelled with the node's sequence under every observed
         # (frequency > 0) node. Iterate over a snapshot because nodes are
         # added during the loop.
         for treex in (tree1_copy, tree2_copy):
             for node in list(treex.traverse()):
                 if node.frequency > 0:
                     child = TreeNode()
                     child.add_feature('sequence', node.sequence)
                     node.add_child(child)
         rf_kwargs = dict(attr_t1='sequence',
                          attr_t2='sequence',
                          unrooted_trees=True)
         try:
             return tree1_copy.robinson_foulds(tree2_copy, **rf_kwargs)[0]
         except Exception:
             # Presumably fails on duplicated labels; retry allowing them.
             return tree1_copy.robinson_foulds(tree2_copy,
                                               allow_dup=True,
                                               **rf_kwargs)[0]
     else:
         raise ValueError('invalid distance method: {}'.format(method))
Exemplo n.º 2
0
def addDeadLineage(spTree):
    """
    Graft a "dead" outgroup lineage next to a copy of the species tree.

    Takes:
        - spTree (ete3.Tree) : species tree (left unmodified; a deep copy is used)

    Returns:
        (ete3.Tree) : a new root whose two children are the copied species tree
                      (stem length 0.1) and a dead lineage leaf named "-1".
                      Every node carries a boolean "dead" feature, True only
                      for the dead lineage and the new root.
    """
    aliveCopy = deepcopy(spTree)
    aliveCopy.dist = 0.1

    # every node of the original species tree is alive
    for node in aliveCopy.traverse():
        node.add_feature("dead", False)

    root = TreeNode()
    root.add_feature("dead", True)
    root.dist = 0.0
    root.add_child(aliveCopy)

    # the dead leaf is placed at the same depth as the first leaf of the tree
    firstLeaf = root.get_leaves()[0]
    depth = root.get_distance(firstLeaf)

    dead = TreeNode()
    dead.add_feature("dead", True)
    dead.name = "-1"
    dead.dist = depth
    root.add_child(dead)

    return root
Exemplo n.º 3
0
        def _build(parse_node, tree_node=None):
            """Recursively convert a parse-tree node into a TreeNode.

            LeafNode -> node named after its literal, with a one-element
            "tokens" feature holding its token.
            InternalNode -> node named after its symbol, with a "rule" feature,
            one child per parse child (built recursively) and a "tokens"
            feature concatenating the children's tokens in order.
            Any other input leaves the node untouched (lists are printed).

            tree_node: optional TreeNode to fill in; a fresh one is created
            when omitted. Returns the filled-in node.
            """
            out = tree_node if tree_node is not None else TreeNode()

            if isinstance(parse_node, list):
                # NOTE(review): looks like leftover debugging output
                print(parse_node)

            if isinstance(parse_node, LeafNode):
                literal, tok = parse_node.literal, parse_node.token
                out.name = literal
                out.add_feature("tokens", [tok])
                return out

            if isinstance(parse_node, InternalNode):
                symbol, rule, kids = (parse_node.symbol, parse_node.rule,
                                      parse_node.children)
                collected = []
                for kid in kids:
                    built = _build(kid)
                    out.add_child(built)
                    collected.extend(built.tokens)
                out.name = symbol
                out.add_feature("rule", rule)
                out.add_feature("tokens", collected)

            return out
Exemplo n.º 4
0
def subdivideSpTree(spTree):
    """
    Takes:
        - spTree (ete3.Tree) : an ULTRAMETRIC species tree (left unmodified)

    Returns:
        (ete3.Tree) : subdivided species tree where all nodes have a timeSlice feature
        or
        None if the species tree is not ultrametric

    Nodes are visited by increasing distance from the root. The k-th node
    (k = 0, 1, ...) gets timeSlice max(n - 1 - k, 0), where n is the number of
    leaves. Every branch spanning more than one slice is then broken up by
    inserting unnamed "null" nodes, one per intermediate slice.
    """
    newSpTree = deepcopy(spTree)

    featureName = "timeSlice"

    ##1/ getting distance from root.
    Dheight = getDistFromRootDic(newSpTree, checkUltrametric=True)

    if Dheight is None:
        print("!!ERROR!! : the species tree is not ultrametric")
        return None

    # we know that there is n-1 internal nodes (where n is the number of leaves)
    # hence the maximal timeSlice is n-1 (all leaves have timeSlice 0)

    ##2/assign timeSlice to nodes
    currentTS = len(newSpTree.get_leaves()) - 1

    # Sort on height only. The former (height, node) key compared TreeNode
    # objects whenever heights tie (always, for the leaves of an ultrametric
    # tree): arbitrary in Python 2, a TypeError in Python 3. sorted() is
    # stable, so tied nodes keep their dictionary order.
    for node, _height in sorted(Dheight.items(), key=lambda item: item[1]):
        node.add_feature(featureName, currentTS)

        if currentTS != 0:
            currentTS -= 1

    ##3/subdivide according to timeSlice
    # Snapshot the original nodes so inserted null nodes are not revisited.
    RealNodes = list(newSpTree.traverse())

    for n in RealNodes:
        if n.is_root():
            continue

        # number of slices strictly between n and its parent
        nodeToAdd = n.up.timeSlice - n.timeSlice - 1

        while nodeToAdd > 0:
            parentNode = n.up

            n.detach()

            NullNode = TreeNode()
            NullNode.add_feature(featureName, parentNode.timeSlice - 1)

            # carry the "dead" flag over when the tree has one
            if "dead" in n.features:
                NullNode.add_feature("dead", n.dead)

            parentNode.add_child(NullNode)
            NullNode.add_child(n)
            nodeToAdd -= 1

    return newSpTree
Exemplo n.º 5
0
    def simulate(self,
                 sequence,
                 pair_bounds=None,
                 lambda_=0.9,
                 lambda0=None,
                 N=None,
                 T=None,
                 n=None,
                 verbose=False,
                 selection_params=None):
        '''
        Simulate a poisson branching process with mutation introduced
        by the chosen mutation model e.g. motif or uniform.
        Can either simulate under a neutral model without selection,
        or using an affinity muturation inspired model for selection.

        sequence: nucleotide sequence of the root.
        pair_bounds: optional pair of (start, end) slice bounds; when given,
            the two slices of a parent sequence are mutated separately (with
            lambda0[0] and lambda0[1]) and concatenated.
        lambda_: mean of the default Poisson progeny distribution.
        lambda0: list of mutation rates passed to self.mutate; defaults to [1].
        N: keep simulating while fewer than N leaves are unterminated
            (mutually exclusive with T).
        T: list of sampling times; the simulation runs until max(T) and the
            earlier entries are intermediate sampling times.
        n: list of sample sizes (a single value, or one per entry of T).
        verbose: print progress information.
        selection_params: tuple of affinity-selection settings (unpacked
            below); requires T.

        Returns the uncollapsed simulated tree. Sampled nodes have
        frequency 1, lineages without sampled nodes are pruned, unobserved
        unifurcations are removed, and nodes are named simcell_1, simcell_2, ...

        Raises ValueError for inconsistent N / T / n arguments and
        RuntimeError when the process yields too few leaves to sample.
        '''
        import numpy as np

        if lambda0 is None:
            lambda0 = [1]  # avoid a shared mutable default argument
        progeny = poisson(lambda_)  # Default progeny distribution
        stop_dist = None  # Default stopping criterium for affinity simulation
        # Checking the validity of the input parameters:
        if N is not None and T is not None:
            raise ValueError(
                'Only one of N and T can be used. One must be None.')
        if selection_params is not None and T is None:
            raise ValueError(
                'Simulation with selection was chosen. A time, T, must be specified.'
            )
        elif N is None and T is None:
            raise ValueError('Either N or T must be specified.')
        if N is not None and n is not None and n[-1] > N:
            raise ValueError('n ({}) must not larger than N ({})'.format(
                n[-1], N))
        elif N is not None and n is not None and len(n) != 1:
            raise ValueError(
                'n ({}) must a single value when specifying N'.format(n))
        if T is not None and len(T) > 1 and (n is None or
                                             (len(n) != 1
                                              and len(n) != len(T))):
            raise ValueError(
                'n must be specified when using intermediate sampling:', n)
        elif T is not None and len(T) > 1 and len(n) == 1:
            # A single sample size applies to every sampling time.
            n = [n[-1]] * len(T)

        # Planting the tree:
        tree = TreeNode()
        tree.dist = 0
        tree.add_feature('sequence', sequence)
        tree.add_feature('terminated', False)
        tree.add_feature('sampled', False)
        tree.add_feature('frequency', 0)
        tree.add_feature('time', 0)

        if selection_params is not None:
            # Collect a histogram of the hamming distance counts at each time step
            hd_generation = []
            (stop_dist, mature_affy, naive_affy, target_dist, target_count,
             skip_update, A_total, B_total, Lp, k, outbase) = selection_params
            # Make a list of target sequences:
            targetAAseqs = [
                self.one_mutant(sequence, target_dist)
                for _ in range(target_count)
            ]
            # Assert that the target sequences are comparable to the naive sequence:
            aa = translate(tree.sequence)
            # All targets are same length:
            assert sum([1 for t in targetAAseqs if len(t) != len(aa)]) == 0
            # At least one target is exactly "target_dist" away from the naive sequence:
            assert sum([
                1 for t in targetAAseqs
                if hamming_distance(aa, t) == target_dist
            ])
            assert target_dist > 0

            # Affinity goes from mature_affy at hamming distance 0 to
            # naive_affy at hamming distance target_dist, as a power k of hd:
            def hd2affy(hd):
                return (mature_affy + hd**k *
                        (naive_affy - mature_affy) / target_dist**k)

            # We store both the amino acid sequence and the affinity as tree features:
            tree.add_feature('AAseq', str(aa))
            tree.add_feature(
                'Kd', selection_utils.calc_Kd(tree.AAseq, targetAAseqs,
                                              hd2affy))
            tree.add_feature(
                'target_dist',
                min([
                    hamming_distance(tree.AAseq, taa) for taa in targetAAseqs
                ]))

        t = 0  # <-- Time at start
        leaves_unterminated = 1
        # Small lambdas are causing problems so make a minimum:
        lambda_min = 10e-10
        hd_distrib = []
        # Continue while the population survives, is below N (if given), has
        # not reached max(T) (if given) and, when stop_dist is set, after the
        # first step: stop_dist >= the smallest hamming distance to a target.
        # NOTE(review): this continues only while some cell is already within
        # stop_dist of a target; confirm the intended comparison direction.
        while (leaves_unterminated > 0
               and (N is None or leaves_unterminated < N)
               and (T is None or t < max(T))
               and (stop_dist is None or t <= 0
                    or stop_dist >= min(hd_distrib))):
            if verbose:
                print('At time:', t)
            t += 1
            # Sample intermediate time point:
            if T is not None and len(T) > 1 and (t - 1) in T:
                si = T.index(t - 1)
                live_nostop_leaves = [
                    cell for cell in tree.iter_leaves()
                    if not cell.terminated and not has_stop(cell.sequence)
                ]
                random.shuffle(live_nostop_leaves)
                if len(live_nostop_leaves) < n[si]:
                    raise RuntimeError(
                        'tree with {} leaves, less than what desired for intermediate sampling {}. Try later generation or increasing the carrying capacity.'
                        .format(len(live_nostop_leaves), n[si]))
                # Make the sample and kill the cells sampled:
                for leaf in live_nostop_leaves[:n[si]]:
                    leaves_unterminated -= 1
                    leaf.sampled = True
                    leaf.terminated = True
                if verbose:
                    print('Made an intermediate sample at time:', t - 1)
            live_leaves = [
                cell for cell in tree.iter_leaves() if not cell.terminated
            ]
            random.shuffle(live_leaves)
            skip_lambda_n = 0  # At every new round reset the all the lambdas
            # Draw progeny for each leaf:
            for leaf in live_leaves:
                if selection_params is not None:
                    if skip_lambda_n == 0:
                        skip_lambda_n = skip_update + 1  # Add one so skip_update=0 is no skip
                        tree = selection_utils.lambda_selection(
                            tree, targetAAseqs, hd2affy, A_total, B_total, Lp)
                    if leaf.lambda_ > lambda_min:
                        progeny = poisson(leaf.lambda_)
                    else:
                        progeny = poisson(lambda_min)
                    skip_lambda_n -= 1
                n_children = progeny.rvs()
                leaves_unterminated += n_children - 1  # <-- Getting 1, is equal to staying alive
                if not n_children:
                    leaf.terminated = True
                for _ in range(n_children):
                    # If sequence pair mutate them separately with their own mutation rate:
                    if pair_bounds is not None:
                        mutated_sequence1 = self.mutate(
                            leaf.sequence[pair_bounds[0][0]:pair_bounds[0][1]],
                            lambda0=lambda0[0])
                        mutated_sequence2 = self.mutate(
                            leaf.sequence[pair_bounds[1][0]:pair_bounds[1][1]],
                            lambda0=lambda0[1])
                        mutated_sequence = mutated_sequence1 + mutated_sequence2
                    else:
                        mutated_sequence = self.mutate(leaf.sequence,
                                                       lambda0=lambda0[0])
                    child = TreeNode()
                    # branch length = number of differing positions
                    child.dist = sum(
                        x != y
                        for x, y in zip(mutated_sequence, leaf.sequence))
                    child.add_feature('sequence', mutated_sequence)
                    if selection_params is not None:
                        aa = translate(child.sequence)
                        child.add_feature('AAseq', str(aa))
                        child.add_feature(
                            'Kd',
                            selection_utils.calc_Kd(child.AAseq, targetAAseqs,
                                                    hd2affy))
                        child.add_feature(
                            'target_dist',
                            min([
                                hamming_distance(child.AAseq, taa)
                                for taa in targetAAseqs
                            ]))
                    child.add_feature('frequency', 0)
                    child.add_feature('terminated', False)
                    child.add_feature('sampled', False)
                    child.add_feature('time', t)
                    leaf.add_child(child)
            if selection_params is not None:
                hd_distrib = [
                    min([
                        hamming_distance(tn.AAseq, ta) for ta in targetAAseqs
                    ]) for tn in tree.iter_leaves() if not tn.terminated
                ]
                if target_dist > 0:
                    hist = np.histogram(hd_distrib,
                                        bins=list(range(target_dist * 10)))
                else:  # Just make a minimum of 10 bins
                    hist = np.histogram(hd_distrib, bins=list(range(10)))
                hd_generation.append(hist)
                if verbose and hd_distrib:
                    print('Total cell population:', sum(hist[0]))
                    print('Majority hamming distance:', np.argmax(hist[0]))
                    print('Affinity of latest sampled leaf:', leaf.Kd)
                    print(
                        'Progeny distribution lambda for the latest sampled leaf:',
                        leaf.lambda_)

        # N is None for time-based (T) runs; comparing an int with None
        # raises TypeError on Python 3, so only check when N was given.
        if N is not None and leaves_unterminated < N:
            raise RuntimeError(
                'Tree terminated with {} leaves, {} desired'.format(
                    leaves_unterminated, N))

        # Keep a histogram of the hamming distances at each generation:
        if selection_params is not None:
            with open(outbase + '_selection_runstats.p', 'wb') as f:
                pickle.dump(hd_generation, f)

        # Each leaf in final generation gets an observation frequency of 1, unless downsampled:
        if T is not None and len(T) > 1:
            # Iterate the intermediate time steps (excluding the last time):
            for Ti in sorted(T)[:-1]:
                si = T.index(Ti)
                # Only sample those that have been 'sampled' at intermediate sampling times:
                final_leaves = [
                    leaf for leaf in tree.iter_descendants()
                    if leaf.time == Ti and leaf.sampled
                ]
                if len(final_leaves) < n[si]:
                    raise RuntimeError(
                        'tree terminated with {} leaves, less than what desired after downsampling {}'
                        .format(len(final_leaves), n[si]))
                for leaf in final_leaves:  # No need to down-sample, this was already done in the simulation loop
                    leaf.frequency = 1
        if selection_params and max(T) != t:
            raise RuntimeError(
                'tree terminated with before the requested sample time.')
        # Do the normal sampling of the last time step:
        final_leaves = [
            leaf for leaf in tree.iter_leaves()
            if leaf.time == t and not has_stop(leaf.sequence)
        ]
        # Report stop codon sequences:
        stop_leaves = [
            leaf for leaf in tree.iter_leaves()
            if leaf.time == t and has_stop(leaf.sequence)
        ]
        if stop_leaves:
            print(
                'Tree contains {} leaves with stop codons, out of {} total at last time point.'
                .format(len(stop_leaves), len(final_leaves)))

        if T is not None:
            si = T.index(sorted(T)[-1])
        else:
            si = 0
        # By default, downsample to the target simulation size:
        if n is not None and len(final_leaves) >= n[si]:
            for leaf in random.sample(final_leaves, n[si]):
                leaf.frequency = 1
        elif n is None and N is not None:
            # Removed nonsense sequences might decrease the number of final leaves to less than N
            if len(final_leaves) < N:
                N = len(final_leaves)
            for leaf in random.sample(final_leaves, N):
                leaf.frequency = 1
        elif N is None and T is not None:
            for leaf in final_leaves:
                leaf.frequency = 1
        elif n is not None and len(final_leaves) < n[si]:
            raise RuntimeError(
                'tree terminated with {} leaves, less than what desired after downsampling {}'
                .format(len(final_leaves), n[si]))
        else:
            raise RuntimeError('Unknown option.')

        # Prune away lineages that are unobserved:
        for node in tree.iter_descendants():
            if sum(node2.frequency for node2 in node.traverse()) == 0:
                node.detach()

        # Remove unobserved unifurcations:
        for node in tree.iter_descendants():
            parent = node.up
            if node.frequency == 0 and len(node.children) == 1:
                node.delete(prevent_nondicotomic=False)
                node.children[0].dist = hamming_distance(
                    node.children[0].sequence, parent.sequence)

        # Assign unique names to each node:
        for i, node in enumerate(tree.traverse(), 1):
            node.name = 'simcell_{}'.format(i)

        # Return the uncollapsed tree:
        return tree
Exemplo n.º 6
0
def _nexus_ci_from_text(text):
    """Return the first date confidence interval found in text, or None.

    The LSD2 pattern (CI_DATE_REGEX_LSD) is tried first, then the PastML one
    (CI_DATE_REGEX_PASTML); the result is the first item re.findall yields.
    """
    ci = next(iter(re.findall(CI_DATE_REGEX_LSD, text)), None)
    if ci is None:
        ci = next(iter(re.findall(CI_DATE_REGEX_PASTML, text)), None)
    return ci


def parse_nexus(tree_path, columns=None):
    """Parse every tree of a nexus file into TreeNode objects.

    For each tree yielded by read_nexus, a TreeNode hierarchy is built:
    - name: the clade name, or its confidence if that is a string;
    - dist: the branch length, or 0 when it is missing or not numeric;
    - annotations read from clade comments: a date (DATE feature), a date
      confidence interval (DATE_CI feature, LSD2 or PastML format) and, for
      each requested column, the set of PastML values found (a feature named
      after the column).

    Dates/CIs are then copied across zero-length branches (parent -> child,
    then child -> parent). If the root is still undated, the first dated node
    propagates its date up to the root by subtracting branch lengths.

    :param tree_path: path to the nexus file.
    :param columns: optional iterable of annotation column names to extract.
    :return: list of root TreeNodes, one per tree in the file.
    """
    trees = []
    for nex_tree in read_nexus(tree_path):
        todo = [(nex_tree.root, None)]
        tree = None
        while todo:
            clade, parent = todo.pop()
            dist = 0
            try:
                dist = float(clade.branch_length)
            except (TypeError, ValueError):
                # missing (None) or non-numeric branch length: keep 0
                pass
            name = getattr(clade, 'name', None)
            if not name:
                name = getattr(clade, 'confidence', None)
                if not isinstance(name, str):
                    name = None
            node = TreeNode(dist=dist, name=name)
            if parent is None:
                tree = node
            else:
                parent.add_child(node)

            # Parse LSD2 dates and CIs, and PastML columns
            date, ci = None, None
            columns2values = defaultdict(set)
            comment = getattr(clade, 'comment', None)
            if isinstance(comment, str):
                date = next(iter(re.findall(DATE_COMMENT_REGEX, comment)),
                            None)
                ci = _nexus_ci_from_text(comment)
                if columns:
                    for column in columns:
                        values = set.union(
                            *(set(_.split('|')) for _ in re.findall(
                                COLUMN_REGEX_PASTML.format(column=column),
                                comment)), set())
                        if values:
                            columns2values[column] |= values
            # For the root, also look for a CI in a string branch length.
            branch_length = getattr(clade, 'branch_length', None)
            if not ci and parent is None and isinstance(branch_length, str):
                ci = _nexus_ci_from_text(branch_length)
            # Otherwise, look for a CI in a string confidence.
            confidence = getattr(clade, 'confidence', None)
            if ci is None and isinstance(confidence, str):
                ci = _nexus_ci_from_text(confidence)
            if date is not None:
                try:
                    node.add_feature(DATE, float(date))
                except (TypeError, ValueError):
                    pass  # unparsable date: leave the node undated
            if ci is not None:
                try:
                    node.add_feature(DATE_CI, [float(_) for _ in ci])
                except (TypeError, ValueError):
                    pass  # unparsable CI: leave it out
            for c, vs in columns2values.items():
                node.add_feature(c, vs)
            # The stack is LIFO: push children reversed so that they are
            # popped, and hence attached, in their original order.
            todo.extend((c, node) for c in reversed(list(clade.clades)))
        # Copy dates/CIs down to children at distance 0 ...
        for n in tree.traverse('preorder'):
            date, ci = getattr(n, DATE, None), getattr(n, DATE_CI, None)
            if date is not None or ci is not None:
                for c in n.children:
                    if c.dist == 0:
                        if getattr(c, DATE, None) is None:
                            c.add_feature(DATE, date)
                        if getattr(c, DATE_CI, None) is None:
                            c.add_feature(DATE_CI, ci)
        # ... and up to parents at distance 0.
        for n in tree.traverse('postorder'):
            date, ci = getattr(n, DATE, None), getattr(n, DATE_CI, None)
            if not n.is_root() and n.dist == 0 and (date is not None
                                                    or ci is not None):
                if getattr(n.up, DATE, None) is None:
                    n.up.add_feature(DATE, date)
                if getattr(n.up, DATE_CI, None) is None:
                    n.up.add_feature(DATE_CI, ci)

        # propagate dates up to the root if needed
        if getattr(tree, DATE, None) is None:
            dated_node = next((n for n in tree.traverse()
                               if getattr(n, DATE, None) is not None), None)
            if dated_node is not None:
                while dated_node is not tree:
                    if getattr(dated_node.up, DATE, None) is None:
                        dated_node.up.add_feature(
                            DATE,
                            getattr(dated_node, DATE) - dated_node.dist)
                    dated_node = dated_node.up

        trees.append(tree)
    return trees
Exemplo n.º 7
0
        ShowFormat()
        sys.exit(-1)

    basehtml = args['--html'] if args['--html'] else 'base.html'

    from ete3 import Tree, TreeNode
    #read ped file from stdin.
    ped_data = {}  #map for name -> raw data.
    node_data = {}  #map for name -> TreeNode
    for line in sys.stdin:
        line = line.strip()
        if line and line[0] != '#':  #skip comment line.
            ss = line.split()
            ped_data[ss[1]] = ss
            n = TreeNode(name=ss[1])
            n.add_feature('raw', ss)
            node_data[ss[1]] = n

    # for k,v in node_data.items():
    #     print(v.write(format=2,features=['raw']))

    #find the root node, and convert results to josn.
    #Check data integrity.
    m_error = False
    for _, data in ped_data.items():
        if data[2] != '0' and data[2] not in ped_data.keys():
            m_error = True
            sys.stderr.write('ERROR: missing declearation for father: %s\n' %
                             (data[2]))
        if data[3] != '0' and data[3] not in ped_data.keys():
            m_error = True
Exemplo n.º 8
0
def create_tree_node(seq, frequency=0):
    """Return a new TreeNode with 'sequence' and 'frequency' features set.

    seq: value stored as the node's 'sequence' feature.
    frequency: value stored as the node's 'frequency' feature (default 0).
    """
    node = TreeNode()
    for feature, value in (('sequence', seq), ('frequency', frequency)):
        node.add_feature(feature, value)
    return node
Exemplo n.º 9
0
class CollapsedTree(LeavesAndClades):
    '''
    Here's a derived class for a collapsed tree, where we recurse into the mutant clades
          (4)
         / | \\
       (3)(1)(2)
           |   \\
          (2)  (1)
    '''
    def __init__(self,
                 params=None,
                 tree=None,
                 frame=None,
                 collapse_syn=False,
                 allow_repeats=False):
        '''
        For initialization, either params or tree (or both) must be provided
        params: offspring distribution parameters
        tree: ete tree with frequency node feature. If uncollapsed, it will be collapsed
        frame: translation frame, with default None, no translation attempted
        collapse_syn: if True, first overwrite the branch lengths of the input tree
            (in place) with the amino acid hamming distance to the parent, so that
            branches carrying only synonymous changes get length 0 and are collapsed;
            requires both tree and frame
        allow_repeats: if False, raise when an observed sequence occurs on more than
            one node of the collapsed tree; if True, only print a message
        raises RuntimeError: for an invalid frame, if the observed genotypes change
            during collapse, or for repeated observed sequences unless allow_repeats
        raises ValueError: if collapse_syn is requested without a tree or a frame
        '''
        LeavesAndClades.__init__(self, params=params)
        if frame is not None and frame not in (1, 2, 3):
            raise RuntimeError('frame must be 1, 2, 3, or None')
        self.frame = frame

        if collapse_syn is True:
            # translation needs both a tree to annotate and a reading frame
            if tree is None or frame is None:
                raise ValueError(
                    'collapse_syn requires both a tree and a frame')
            tree.dist = 0  # no branch above root
            for node in tree.iter_descendants():
                # the parent is translated over the child's codon span
                aa = self._translate(node.sequence, frame, len(node.sequence))
                aa_parent = self._translate(node.up.sequence, frame,
                                            len(node.sequence))
                node.dist = hamming_distance(aa, aa_parent)

        if tree is not None:
            self.tree = tree.copy()
            # remove unobserved internal unifurcations
            for node in self.tree.iter_descendants():
                parent = node.up
                if node.frequency == 0 and len(node.children) == 1:
                    node.delete(prevent_nondicotomic=False)
                    node.children[0].dist = hamming_distance(
                        node.children[0].sequence, parent.sequence)

            # iterate over the tree below root and collapse edges of zero length
            # if the node is a leaf and its parent has nonzero frequency we combine taxa names to a set
            # this accommodates bootstrap samples that result in repeated genotypes
            observed_genotypes = set(leaf.name for leaf in self.tree)
            observed_genotypes.add(self.tree.name)
            for node in self.tree.get_descendants(strategy='postorder'):
                if node.dist == 0:
                    node.up.frequency += node.frequency
                    node_set = set(self._taxa_names(node.name))
                    node_up_set = set(self._taxa_names(node.up.name))
                    if node_up_set < observed_genotypes:
                        if node_set < observed_genotypes:
                            node.up.name = tuple(node_set | node_up_set)
                            if len(node.up.name) == 1:
                                node.up.name = node.up.name[0]
                    elif node_set < observed_genotypes:
                        node.up.name = tuple(node_set)
                        if len(node.up.name) == 1:
                            node.up.name = node.up.name[0]
                    node.delete(prevent_nondicotomic=False)

            final_observed_genotypes = self._observed_taxa(
                self.tree.traverse())
            if final_observed_genotypes != observed_genotypes:
                raise RuntimeError(
                    'observed genotypes don\'t match after collapse\n\tbefore: {}\n\tafter: {}\n\tsymmetric diff: {}'
                    .format(observed_genotypes, final_observed_genotypes,
                            observed_genotypes ^ final_observed_genotypes))
            assert sum(node.frequency for node in tree.traverse()) == sum(
                node.frequency for node in self.tree.traverse())

            # number of observed nodes whose sequence also occurs on another observed node
            observed_seqs = [
                node.sequence for node in self.tree.traverse()
                if node.frequency > 0
            ]
            rep_seq = len(observed_seqs) - len(set(observed_seqs))
            if rep_seq:
                message = 'Repeated observed sequences in collapsed tree. {} sequences were found repeated.'.format(
                    rep_seq)
                if not allow_repeats:
                    raise RuntimeError(message)
                print(message)
            # a custom ladderize accounting for abundance and sequence to break ties in abundance
            for node in self.tree.traverse(strategy='postorder'):
                # add a partition feature and compute it recursively up the tree
                node.add_feature(
                    'partition',
                    node.frequency + sum(child.partition
                                         for child in node.children))
                # sort children of this node based on partition and sequence
                node.children.sort(
                    key=lambda child: (child.partition, child.sequence))
        else:
            self.tree = tree

    @staticmethod
    def _translate(sequence, frame, ref_length):
        '''
        translate sequence starting at 1-based reading frame `frame`, keeping only
        the whole codons that fit in a sequence of length ref_length
        '''
        start = frame - 1
        stop = start + 3 * ((ref_length - start) // 3)
        return Seq(sequence[start:stop], generic_dna).translate()

    @staticmethod
    def _taxa_names(name):
        '''return a node name (a single str or a collection of str) as a tuple of taxa names'''
        return (name, ) if isinstance(name, str) else tuple(name)

    def _observed_taxa(self, nodes):
        '''set of taxa names carried by those of `nodes` that are observed (frequency > 0) or are the root'''
        return set(name for node in nodes
                   if node.frequency > 0 or node == self.tree
                   for name in self._taxa_names(node.name))

    def l(self, params, sign=1):
        '''
        log likelihood of params, conditioned on collapsed tree, and its gradient wrt params
        optional parameter sign must be 1 or -1, with the latter useful for MLE by minimization
        raises ValueError: if there is no tree or sign is not 1 or -1
        '''
        if self.tree is None:
            raise ValueError('tree data must be defined to compute likelihood')
        if sign not in (-1, 1):
            raise ValueError('sign must be 1 or -1')
        leaves_and_clades_list = [
            LeavesAndClades(c=node.frequency, m=len(node.children))
            for node in self.tree.traverse()
        ]
        if leaves_and_clades_list[0].c == 0 and leaves_and_clades_list[
                0].m == 1 and leaves_and_clades_list[0].f(params)[0] == 0:
            # if unifurcation not possible under current model, add a pseudocount for the naive
            leaves_and_clades_list[0].c = 1
        # extract vector of function values and gradient components
        f_data = [
            leaves_and_clades.f(params)
            for leaves_and_clades in leaves_and_clades_list
        ]
        fs = scipy.array([[x[0]] for x in f_data])
        logf = scipy.log(fs).sum()
        grad_fs = scipy.array([x[1] for x in f_data])
        grad_logf = (grad_fs / fs).sum(axis=0)
        return sign * logf, sign * grad_logf

    def mle(self, **kwargs):
        '''
        Maximum likelihood estimate for params given tree
        sets self.params to the estimate if it is None and optimization succeeded
        kwargs are forwarded to l() (with 'sign' forced to -1)
        returns the scipy optimization result
        '''
        # random initialization
        x_0 = (random.random(), random.random())
        bounds = ((.01, .99), (.001, .999))
        kwargs['sign'] = -1
        # sanity-check the analytic gradient against finite differences at a fixed point
        grad_check = check_grad(lambda x: self.l(x, **kwargs)[0],
                                lambda x: self.l(x, **kwargs)[1], (.4, .5))
        if grad_check > 1e-3:
            warnings.warn(
                'gradient mismatches finite difference approximation by {}'.
                format(grad_check), RuntimeWarning)
        result = minimize(lambda x: self.l(x, **kwargs),
                          x0=x_0,
                          jac=True,
                          method='L-BFGS-B',
                          options={'ftol': 1e-10},
                          bounds=bounds)
        # update params if None and optimization successful
        if not result.success:
            # str(): some SciPy versions report the L-BFGS-B message as bytes
            warnings.warn('optimization not sucessful, ' + str(result.message),
                          RuntimeWarning)
        elif self.params is None:
            self.params = result.x.tolist()
        return result

    def simulate(self):
        '''
        simulate a collapsed tree given params
        replaces existing tree data member with simulation result, and returns self
        raises ValueError: if params is None
        '''
        if self.params is None:
            raise ValueError('params must be defined for simulation')

        # initiate by running a LeavesAndClades simulation to get the number of clones and mutants
        # in the root node of the collapsed tree
        LeavesAndClades.simulate(self)
        self.tree = TreeNode()
        self.tree.add_feature('frequency', self.c)
        if self.m == 0:
            return self
        for _ in range(self.m):
            # each mutant clade is itself a simulated collapsed tree
            child = CollapsedTree(params=self.params,
                                  frame=self.frame).simulate().tree
            child.dist = 1
            self.tree.add_child(child)

        return self

    def __str__(self):
        '''return a string representation for printing'''
        return 'params = ' + str(self.params) + '\ntree:\n' + str(self.tree)

    def render(self,
               outfile,
               idlabel=False,
               colormap=None,
               show_support=False,
               chain_split=None):
        '''
        render to image file, filetype inferred from suffix, svg for color images
        idlabel: also label nodes by name, and write a fasta alignment of all node
            sequences next to outfile (same stem, .fasta suffix)
        colormap: optional map from node name to a color, or to a {color: count} dict for pie charts
        chain_split: index splitting the sequence into two chains, used to color mutated branches
        raises NotImplementedError: if both chain_split and a frame are set
        '''
        def my_layout(node):
            circle_color = 'lightgray' if colormap is None or node.name not in colormap else colormap[
                node.name]
            text_color = 'black'
            if isinstance(circle_color, str):
                C = CircleFace(radius=max(3, 10 * scipy.sqrt(node.frequency)),
                               color=circle_color,
                               label={
                                   'text': str(node.frequency),
                                   'color': text_color
                               } if node.frequency > 0 else None)
                C.rotation = -90
                C.hz_align = 1
                faces.add_face_to_node(C, node, 0)
            else:
                P = PieChartFace(
                    [100 * x / node.frequency for x in circle_color.values()],
                    2 * 10 * scipy.sqrt(node.frequency),
                    2 * 10 * scipy.sqrt(node.frequency),
                    colors=[(color if color != 'None' else 'lightgray')
                            for color in list(circle_color.keys())],
                    line_color=None)
                T = TextFace(' '.join(
                    [str(x) for x in list(circle_color.values())]),
                             tight_text=True)
                T.hz_align = 1
                T.rotation = -90
                faces.add_face_to_node(P, node, 0, position='branch-right')
                faces.add_face_to_node(T, node, 1, position='branch-right')
            if idlabel:
                T = TextFace(node.name, tight_text=True, fsize=6)
                T.rotation = -90
                T.hz_align = 1
                faces.add_face_to_node(
                    T,
                    node,
                    1 if isinstance(circle_color, str) else 2,
                    position='branch-right')

        for node in self.tree.traverse():
            nstyle = NodeStyle()
            nstyle['size'] = 0
            if node.up is not None:
                if set(node.sequence.upper()) == set('ACGT'):
                    if chain_split is not None:
                        if self.frame is not None:
                            raise NotImplementedError(
                                'frame not implemented with chain_split')
                        leftseq_mutated = hamming_distance(
                            node.sequence[:chain_split],
                            node.up.sequence[:chain_split]) > 0
                        rightseq_mutated = hamming_distance(
                            node.sequence[chain_split:],
                            node.up.sequence[chain_split:]) > 0
                        if leftseq_mutated and rightseq_mutated:
                            nstyle['hz_line_color'] = 'purple'
                            nstyle['hz_line_width'] = 3
                        elif leftseq_mutated:
                            nstyle['hz_line_color'] = 'red'
                            nstyle['hz_line_width'] = 2
                        elif rightseq_mutated:
                            nstyle['hz_line_color'] = 'blue'
                            nstyle['hz_line_width'] = 2
                    if self.frame is not None:
                        # the parent is translated over the child's codon span
                        aa = self._translate(node.sequence, self.frame,
                                             len(node.sequence))
                        aa_parent = self._translate(node.up.sequence,
                                                    self.frame,
                                                    len(node.sequence))
                        nonsyn = hamming_distance(aa, aa_parent)
                        if '*' in aa:
                            nstyle['bgcolor'] = 'red'
                        if nonsyn > 0:
                            nstyle['hz_line_color'] = 'black'
                            nstyle['hz_line_width'] = nonsyn
                        else:
                            nstyle['hz_line_type'] = 1
            node.set_style(nstyle)

        ts = TreeStyle()
        ts.show_leaf_name = False
        ts.rotation = 90
        ts.draw_aligned_faces_as_table = False
        ts.allow_face_overlap = True
        ts.layout_fn = my_layout
        ts.show_scale = False
        ts.show_branch_support = show_support
        self.tree.render(outfile, tree_style=ts)
        # if we labelled seqs, let's also write the alignment out so we have the sequences (including of internal nodes)
        if idlabel:
            aln = MultipleSeqAlignment([])
            for node in self.tree.traverse():
                aln.append(
                    SeqRecord(Seq(str(node.sequence), generic_dna),
                              id=str(node.name),
                              description='abundance={}'.format(
                                  node.frequency)))
            # close the fasta file deterministically once written
            with open(os.path.splitext(outfile)[0] + '.fasta',
                      'w') as fasta_file:
                AlignIO.write(aln, fasta_file, 'fasta')

    def write(self, file_name):
        '''serialize tree to file (pickle)'''
        with open(file_name, 'wb') as f:
            pickle.dump(self, f)

    def compare(self, tree2, method='identity'):
        '''
        compare this tree to the other tree
        method 'identity': True iff both trees have the same sorted list of
            (sequence, frequency, parent sequence) triples
        method 'MRCA': summed hamming distance between the MRCA sequences of every
            pair of observed taxa of this tree in both trees, divided by the summed
            MRCA sequence lengths in this tree
        method 'RF': Robinson-Foulds distance (unrooted, matched on sequence) after
            giving every observed node an extra leaf child carrying its sequence
        raises ValueError: for an unknown method
        '''
        if method == 'identity':
            # we compare lists of seq, parent, abundance
            # return true if these lists are identical, else false
            list1 = sorted((node.sequence, node.frequency,
                            node.up.sequence if node.up is not None else None)
                           for node in self.tree.traverse())
            list2 = sorted((node.sequence, node.frequency,
                            node.up.sequence if node.up is not None else None)
                           for node in tree2.tree.traverse())
            return list1 == list2
        elif method == 'MRCA':
            # matrix of hamming distance of common ancestors of taxa
            # takes a true and inferred tree as CollapsedTree objects
            taxa = [
                node.sequence for node in self.tree.traverse()
                if node.frequency
            ]
            n_taxa = len(taxa)
            # look each taxon up once (first node carrying its sequence) in both trees
            true_nodes = [
                next(self.tree.iter_search_nodes(sequence=seq))
                for seq in taxa
            ]
            inferred_nodes = [
                next(tree2.tree.iter_search_nodes(sequence=seq))
                for seq in taxa
            ]
            d = scipy.zeros(shape=(n_taxa, n_taxa))
            sum_sites = scipy.zeros(shape=(n_taxa, n_taxa))
            for i in range(n_taxa):
                for j in range(i + 1, n_taxa):
                    MRCA_true = self.tree.get_common_ancestor(
                        (true_nodes[i], true_nodes[j])).sequence
                    MRCA = tree2.tree.get_common_ancestor(
                        (inferred_nodes[i], inferred_nodes[j])).sequence
                    d[i, j] = hamming_distance(MRCA_true, MRCA)
                    sum_sites[i, j] = len(MRCA_true)
            return d.sum() / sum_sites.sum()
        elif method == 'RF':
            tree1_copy = self.tree.copy(method='deepcopy')
            tree2_copy = tree2.tree.copy(method='deepcopy')
            for treex in (tree1_copy, tree2_copy):
                # snapshot the nodes first since children are added while looping
                for node in list(treex.traverse()):
                    if node.frequency > 0:
                        child = TreeNode()
                        child.add_feature('sequence', node.sequence)
                        node.add_child(child)
            try:
                return tree1_copy.robinson_foulds(tree2_copy,
                                                  attr_t1='sequence',
                                                  attr_t2='sequence',
                                                  unrooted_trees=True)[0]
            except Exception:
                # retry tolerating duplicated sequence attributes
                return tree1_copy.robinson_foulds(tree2_copy,
                                                  attr_t1='sequence',
                                                  attr_t2='sequence',
                                                  unrooted_trees=True,
                                                  allow_dup=True)[0]
        else:
            raise ValueError('invalid distance method: ' + method)

    def get_split(self, node):
        '''
        return the bipartition resulting from clipping this node's edge above
        the result is a tuple of two disjoint sets of taxa names: those observed in
        the clade below node and those observed elsewhere (the root's name always
        counts as observed); the tree is not modified
        raises ValueError: if node is not in this tree or is the root
        '''
        if node.get_tree_root() != self.tree:
            raise ValueError('node not found')
        if node == self.tree:
            raise ValueError('this node is the root (no split above)')
        taxa1 = self._observed_taxa(node.traverse())
        # walk the whole tree but do not descend below node, then drop node itself,
        # which yields the complement of node's clade without detaching it
        taxa2 = self._observed_taxa(
            other
            for other in self.tree.traverse(is_leaf_fn=lambda n: n is node)
            if other is not node)
        assert taxa1.isdisjoint(taxa2)
        assert taxa1.union(taxa2) == self._observed_taxa(self.tree.traverse())
        # sets sort by inclusion, so for these disjoint sides the order is kept
        return tuple(sorted([taxa1, taxa2]))

    @staticmethod
    def split_compatibility(split1, split2):
        '''
        return True if some side of split1 is disjoint from some side of split2
        raises ValueError: if the two splits do not cover the same taxa
        '''
        diff = split1[0].union(split1[1]) ^ split2[0].union(split2[1])
        if diff:
            raise ValueError(
                'splits do not cover the same taxa\n\ttaxa not in both: {}'.
                format(diff))
        for partition1 in split1:
            for partition2 in split2:
                if partition1.isdisjoint(partition2):
                    return True
        return False

    def support(self, bootstrap_trees_list, weights=None, compatibility=False):
        '''
        compute support from a list of bootstrap GCtrees
        weights (optional) is needed for weighting parsimony degenerate trees
        compatibility mode counts trees that don't disconfirm the split
        sets node.support on every non-root node and returns self
        '''
        for node in self.tree.get_descendants():
            split = self.get_split(node)
            support = 0
            compatibility_ = 0
            for i, tree in enumerate(bootstrap_trees_list):
                compatible = True
                supported = False
                for boot_node in tree.tree.get_descendants():
                    boot_split = tree.get_split(boot_node)
                    if compatibility and compatible and not self.split_compatibility(
                            split, boot_split):
                        compatible = False
                    if not compatibility and not supported and boot_split == split:
                        supported = True
                if supported:
                    support += weights[i] if weights is not None else 1
                if compatible:
                    compatibility_ += weights[i] if weights is not None else 1
            node.support = compatibility_ if compatibility else support

        return self
Exemplo n.º 10
0
    def simulate(
        self,
        sequence: str,
        seq_bounds: Tuple[Tuple[int, int], Tuple[int, int]] = None,
        fitness_function: Callable = lambda seq: 0.9,
        lambda0: List[np.float64] = [1],
        frame: int = None,
        N_init: int = 1,
        N: int = None,
        T: int = None,
        n: int = None,
        verbose: bool = False,
    ) -> TreeNode:
        r"""Simulate a branching process with the mutation model, returning a :class:`ete3.TreeNode` object.

        Each generation, every unterminated leaf draws a Poisson number of
        children with mean ``fitness_function(sequence)``; a leaf drawing zero
        children is terminated.

        Args:
            sequence: root nucleotide sequence
            seq_bounds: ranges for two subsequences used as two parallel genes
            fitness_function: mean number offspring as a function of sequence
            lambda0: baseline mutation rate(s); two entries are needed when
                ``seq_bounds`` is given
            frame: coding frame of starting position(s)
            N_init: initial naive abundance
            N: maximum population size
            T: sampling time(s), either a single generation or a collection of
                generations. The simulation runs until the largest one; cells born
                at each earlier time are sampled as intermediate observations
                (they keep proliferating afterwards).
            n: sample size, applied at every sampling time
            verbose: print more messages

        Returns:
            The uncollapsed tree. Sampled cells have ``abundance`` 1, unobserved
            lineages are pruned and nodes are named ``simcell_<i>``.

        Raises:
            ValueError: if both or neither of ``N`` and ``T`` are given, or ``n > N``.
            RuntimeError: if the population ends with fewer than ``N`` leaves, or
                fewer than ``n`` cells are available at a sampling time.
        """
        # Checking the validity of the input parameters:
        if N is not None and T is not None:
            raise ValueError(
                "Only one of N and T can be used. One must be None.")
        elif N is None and T is None:
            raise ValueError("Either N or T must be specified.")
        if N is not None and n is not None and n > N:
            raise ValueError("n ({}) must not larger than N ({})".format(n, N))
        if T is not None:
            # accept a single sampling time as well as a collection of them
            try:
                T = sorted(T)
            except TypeError:
                T = [T]

        # Planting the tree:
        tree = TreeNode()
        tree.dist = 0
        tree.add_feature("sequence", sequence)
        tree.add_feature("terminated", False)
        tree.add_feature("abundance", 0)
        tree.add_feature("time", 0)
        # add fitness attribute, interpreted as mean of offspring distribution
        tree.add_feature("fitness", fitness_function(tree.sequence))

        if N_init > 1:
            for _ in range(N_init):
                child = TreeNode()
                child.dist = 0
                child.add_feature("sequence", sequence)
                child.add_feature("abundance", 0)
                child.add_feature("terminated", False)
                child.add_feature("time", 0)
                # add fitness attribute, interpreted as mean of offspring distribution
                child.add_feature("fitness", fitness_function(child.sequence))
                tree.add_child(child)

        t = 0  # <-- time
        leaves_unterminated = N_init
        while (leaves_unterminated > 0
               and (leaves_unterminated < N if N is not None else True)
               and (t < T[-1] if T is not None else True)):
            if verbose:
                print("At time:", t)
            t += 1
            list_of_leaves = list(tree.iter_leaves())
            random.shuffle(list_of_leaves)
            for leaf in list_of_leaves:
                # add fitness attribute, interpreted as mean of offspring distribution
                leaf.add_feature("fitness", fitness_function(leaf.sequence))
                if not leaf.terminated:
                    n_children = poisson(leaf.fitness).rvs()
                    leaves_unterminated += (
                        n_children - 1
                    )  # <-- this kills the parent if we drew a zero
                    if not n_children:
                        leaf.terminated = True
                    for child_count in range(n_children):
                        # If sequence pair mutate them separately with their own mutation rate:
                        if seq_bounds is not None:
                            mutated_sequence1 = self.mutate(
                                leaf.
                                sequence[seq_bounds[0][0]:seq_bounds[0][1]],
                                lambda0=lambda0[0],
                                frame=frame,
                            )
                            mutated_sequence2 = self.mutate(
                                leaf.
                                sequence[seq_bounds[1][0]:seq_bounds[1][1]],
                                lambda0=lambda0[1],
                                frame=frame,
                            )
                            mutated_sequence = mutated_sequence1 + mutated_sequence2
                        else:
                            mutated_sequence = self.mutate(leaf.sequence,
                                                           lambda0=lambda0[0],
                                                           frame=frame)
                        child = TreeNode()
                        child.dist = utils.hamming_distance(
                            mutated_sequence, leaf.sequence)
                        child.add_feature("sequence", mutated_sequence)
                        child.add_feature("abundance", 0)
                        child.add_feature("terminated", False)
                        child.add_feature("time", t)
                        leaf.add_child(child)

        if N is not None and leaves_unterminated < N:
            raise RuntimeError(
                "tree terminated with {} leaves, {} desired".format(
                    leaves_unterminated, N))

        # Intermediate sampling times: observe cells born at each earlier time,
        # downsampled to n when a sample size is given.
        if T is not None and len(T) > 1:
            for Ti in T[:-1]:
                cells = [
                    node for node in tree.iter_descendants()
                    if node.time == Ti
                ]
                if n is None:
                    sampled = cells
                elif len(cells) < n:
                    raise RuntimeError(
                        "tree terminated with {} leaves, less than what desired after downsampling {}"
                        .format(len(cells), n))
                else:
                    sampled = random.sample(cells, n)
                for cell in sampled:
                    cell.abundance = 1

        # Final sampling: leaves of the last generation get an observed abundance
        # of 1, downsampled to n if given, else to N in population-size mode.
        final_leaves = [leaf for leaf in tree.iter_leaves() if leaf.time == t]
        if n is not None:
            # checked before any mode-specific fallback so a too-small population
            # always raises instead of silently keeping every leaf
            if len(final_leaves) < n:
                raise RuntimeError(
                    "tree terminated with {} leaves, less than what desired after downsampling {}"
                    .format(len(final_leaves), n))
            sampled = random.sample(final_leaves, n)
        elif N is not None:
            sampled = random.sample(final_leaves, N)
        else:
            # time mode without a sample size: observe every final leaf
            sampled = final_leaves
        for leaf in sampled:
            leaf.abundance = 1

        # prune away lineages that are unobserved
        for node in tree.iter_descendants():
            if sum(node2.abundance for node2 in node.traverse()) == 0:
                node.detach()

        # assign unique names to each node
        for i, node in enumerate(tree.traverse(), 1):
            node.name = "simcell_{}".format(i)

        # return the uncollapsed tree
        return tree