Example #1
0
def parse_tree(tree_file):
    """Load a tree and initialise gain/loss counters on every node.

    :param tree_file: whatever the ``Tree`` constructor accepts (presumably a
        Newick file path or Newick string -- confirm against callers).
    :return: the parsed tree; every node, root included, carries the features
        ``gain_count`` and ``loss_count`` set to 0.
    """
    # A parenthesised single-argument print is valid on Python 2 and 3 alike.
    print("[+] Parsing tree...")
    tree = Tree(tree_file)
    # traverse() yields the root as well as all of its descendants, so a
    # single pass initialises the counters everywhere.
    for node in tree.get_tree_root().traverse():
        node.add_features(gain_count=0, loss_count=0)
    return tree
def main(arg1, arg2):
    """Return a split-based distance between the trees built from two inputs.

    Both trees are rooted on the first leaf name of the first tree, and only
    the second child of each resulting root is kept.  The first subtree is
    ordered through ``dfs_assign`` (which also receives the second subtree),
    the second through ``dfs_original``.  Every internal node is counted as a
    split; a split of the second tree is "shared" when its leaf names span a
    contiguous integer range whose "[min:max]" key was recorded for a
    non-root internal node of the first tree.

    :param arg1: input for the first ``Tree``.
    :param arg2: input for the second ``Tree``.
    :return: splits(first) + splits(second) - 2 * shared splits.
    """
    first = Tree(arg1)
    second = Tree(arg2)

    # Root both trees on the same leaf.
    outgroup_name = first.get_leaf_names()[0]
    first.set_outgroup(outgroup_name)
    second.set_outgroup(outgroup_name)

    # Each root must have exactly two children; the first one is discarded.
    _, second = second.get_tree_root().children
    _, first = first.get_tree_root().children

    first_order = dfs_assign([first], second)

    splits_first = 0
    clade_keys = {}
    for node in first_order:
        if node.is_leaf():
            continue
        splits_first += 1
        names = node.get_leaf_names()
        low, high = min(names), max(names)
        if not node.is_root():
            clade_keys["[{}:{}]".format(low, high)] = 1

    splits_second = 0
    shared_count = 0
    for node in dfs_original([second]):
        if node.is_leaf():
            continue
        splits_second += 1
        names = node.get_leaf_names()
        low, high = min(names), max(names)
        # The clade only qualifies when its labels fill the whole range.
        is_contiguous = len(names) == (int(high) - int(low) + 1)
        if is_contiguous and "[{}:{}]".format(low, high) in clade_keys:
            shared_count += 1

    return splits_first + splits_second - (2 * shared_count)
Example #3
0
def scale_tree_to_length(tree,
                         target_dist,
                         outgroup_name=None,
                         scaling_method=ASR_SCALING_MODE):
    """
    Rescale the branch lengths of a tree so that it reaches a target distance.

    :param tree:  newick tree string or txt file containing one tree OR ete3.Tree object
    :param target_dist: numeric, the desired total tree distance
    :param outgroup_name: the name of a tree node if tree needs to be rooted,
                          otherwise the distance is calculated from the inferred root
                          (according to the newick order)
    :param scaling_method: "height" for longest distance from root to leaf,
                           "tbl" for total branch lengths (case-insensitive)
    :return: a newick string of the rescaled tree
    :raises ValueError: if scaling_method is neither "tbl" nor "height"
    """
    t = Tree(get_newick_tree(tree), format=1)

    if outgroup_name:  # re-root tree
        t = reroot_tree(t, outgroup_name)
    root = t.get_tree_root()

    method = scaling_method.lower()
    if method == "tbl":
        dist = get_total_branch_lengths(root)
    elif method == "height":
        dist = get_tree_height(root)
    else:
        # Previously this fell through and failed later with an unbound
        # local variable; fail early with an explicit message instead.
        raise ValueError(
            "unknown scaling_method %r (expected 'tbl' or 'height')" %
            (scaling_method, ))

    scaling_factor = target_dist / dist
    # NOTE(review): the factor is measured on the (possibly re-rooted) copy
    # `t`, but it is applied to the caller's original `tree` -- confirm that
    # this is intended.
    return rescale_tree_branch_lengths(tree, scaling_factor)
Example #4
0
def draw_tree(recipe_inst):
    '''
    Build a chain-shaped tree from a list of recipe steps, print it as ASCII
    art and return its root.

    Every step is a dict with a 'word' (the action) and an 'ingredient' list.
    A step with at least one ingredient becomes a node of type 'action' whose
    children are nodes of type 'ingredient'; that action node is attached
    below the previously attached action node (the first one goes below an
    unnamed root).  Steps whose ingredient list is empty are left out.

    Example input:
    recipe_inst = [{'word': 'heated', 'ingredient':['rice','banana','cookie','dishes']},
                   {'word': 'boil', 'ingredient':['apple','banana','cookie','dish']},
                   {'word': 'rince', 'ingredient':['apple','banana','cookie','dish']}
                  ]
    '''
    # Ingredients are kept in their given order: sorting them will not
    # improve the tree edit distance.
    root = Tree()
    attach_point = root
    for step in recipe_inst:
        action = Tree(name=step['word'])
        action.add_feature('type', 'action')
        ingredients = step['ingredient']
        if not ingredients:
            continue
        for ingredient in ingredients:
            leaf = action.add_child(name=ingredient)
            leaf.add_feature('type', 'ingredient')
        # The next action hangs below the one that was just attached.
        attach_point = attach_point.add_child(action)
    print(root.get_ascii(show_internal=True))
    return root
Example #5
0
def parse_biopp_history(history_path):
    """Read a history tree and turn the ``{digit}`` name suffixes into labels.

    Every non-root node name is expected to look like ``<name>{<digit>}``.
    The suffix is stripped from the name and stored as the feature ``label``:
    "BG" for state "0", "FG" for any other digit.  The root is always
    labelled "BG" and renamed "_baseInternal_30".

    :param history_path: input for ``Tree(..., format=1)``.
    :return: the annotated tree.
    :raises ValueError: if a non-root node name does not match the pattern.
    """
    # Raw string: "\{" and "\d" are regex escapes, not Python string escapes.
    node_data_regex = re.compile(r"([^(|)]*?)\{(\d)\}")
    # read the tree from the file
    history = Tree(history_path, format=1)
    root = history.get_tree_root()
    for node in history.traverse():
        if node == root:
            node.add_feature("label", "BG")  # root is always BG
            continue
        match = node_data_regex.search(node.name)
        if match is None:
            raise ValueError(
                "node name %r does not match '<name>{<state>}'" % node.name)
        node.name = match.group(1)
        node.add_feature("label", "BG" if match.group(2) == "0" else "FG")
    # The root name does not depend on the loop, so it is set once.
    root.name = "_baseInternal_30"
    return history
Example #6
0
def build_conv_topo(annotated_tree, vnodes):
      """Build a topology in which the convergent part of the tree is duplicated.

      Works on a deep copy of ``annotated_tree``; the input is not modified.
      The common ancestor of all nodes tagged ``T=True`` is detached and
      deep-copied.  The original subtree (A) is pruned to its leaves with
      ``C=False``, the copy (B) to its leaves with ``C=True``, and both are
      attached below a new node ``dup_point``.  Finally all nodes are
      renumbered (feature ``ND``) in post-order.

      :param annotated_tree: tree whose nodes carry the features ``T``, ``C``
          and ``ND`` read below (presumably an ete3 tree -- confirm).
      :param vnodes: collection of ``ND`` identifiers; nodes of the copied
          subtree whose ``ND`` is in it keep their branch length.
      :return: the root of the rebuilt tree.
      """

      tconv = annotated_tree.copy(method="deepcopy")
      # tag leaves (L=1) so they can be selected after the subtrees are split
      for n in tconv.iter_leaves():
        n.add_features(L=1)
      # COPY=0 marks the original nodes; the duplicated subtree gets COPY=1
      for n in tconv.traverse():
        n.add_features(COPY=0)
      # get the most recent ancestral node of all the convergent clades
      l_convergent_clades = tconv.search_nodes(T=True)
      common_anc_conv=tconv.get_common_ancestor(l_convergent_clades)

      # duplicate it at its same location (branch length ~ 0). we get
      # a duplicated subtree with subtrees A and B (A == B)

      # NOTE(review): dist_dup is never read again in this function
      dist_dup = common_anc_conv.dist
      if not common_anc_conv.is_root():
        # the duplication point is inserted as a sister of the ancestor
        dup_point = common_anc_conv.add_sister(name="dup_point",dist=0.000001)
        dup_point_root = False
      else:
        # the ancestor is the root: the duplication point becomes a new tree
        dup_point = Tree()
        dup_point_root = True
        dup_point.dist=0.000001

      dup_point.add_features(ND=0,T=False, C=False, Cz=False)

      # A is the detached original subtree, B its deep copy
      common_anc_conv.detach()
      common_anc_conv_copy = common_anc_conv.copy(method="deepcopy")

      # tag duplicated nodes:

      for n in common_anc_conv_copy.traverse():
        n.COPY=1
        # in the copy, every non-root node not listed in vnodes gets a
        # near-zero branch length
        if n.ND not in vnodes and not n.is_root():
            n.dist=0.000001

      # prune A so that it only keeps its leaves tagged C=False
      l_leaves_to_keep_A = common_anc_conv.search_nodes(COPY=0, C=False, L=1)
      #logger.debug("A: %s",l_leaves_to_keep_A)
      common_anc_conv.prune(l_leaves_to_keep_A, preserve_branch_length=True)

      # prune B so that it only keeps its leaves tagged C=True
      l_leaves_to_keep_B = common_anc_conv_copy.search_nodes(COPY=1, C=True, L=1)
      #logger.debug("B : %s", l_leaves_to_keep_B)
      common_anc_conv_copy.prune(l_leaves_to_keep_B, preserve_branch_length=True)


      # B is attached first, then A
      dup_point.add_child(common_anc_conv_copy)
      dup_point.add_child(common_anc_conv)

      tconv = dup_point.get_tree_root()

      # renumber every node in post-order, starting at 0
      nodeId = 0
      for node in tconv.traverse("postorder"):
          node.ND = nodeId
          nodeId += 1

      return tconv
def treeorder(treefile):
    """Return the leaf names of a tree in pre-order.

    Only descendants of the root are inspected, so a tree that consists of a
    single node yields an empty list.

    :param treefile: input for the ete3 ``Tree`` constructor.
    :return: list of leaf names in the order a pre-order traversal meets them.
    """
    # Only Tree is needed here.  The former ete3.treeview import was unused
    # and fails with ImportError on systems without the Qt bindings.
    from ete3 import Tree
    root = Tree(treefile).get_tree_root()
    return [
        desc.name for desc in root.iter_descendants("preorder")
        if desc.is_leaf()
    ]
Example #8
0
def merge_trees(str_arc_tree: str, str_bac_tree: str):
    """Join two Newick trees under a new node named 'root'.

    Both strings go through ``gtdb_format_names`` and are parsed with
    dendropy as force-rooted trees.  The longest edge length found in either
    tree (0 when none is positive) is used as the branch length of both
    subtrees below the new root.

    :param str_arc_tree: Newick string of the first tree.
    :param str_bac_tree: Newick string of the second tree.
    :return: ete3 tree whose root has the two input trees as children.
    """

    def _parse(newick):
        # Parse a Newick string into a force-rooted dendropy tree.
        return dendropy.Tree.get_from_string(newick,
                                             schema='newick',
                                             rooting='force-rooted',
                                             preserve_underscores=True)

    def _longest_edge(dendro_tree, floor):
        # Largest truthy edge length, never smaller than `floor`.
        lengths = [
            edge.length for edge in dendro_tree.postorder_edge_iter()
            if edge.length
        ]
        return max([floor] + lengths)

    def _to_ete(dendro_tree):
        # Re-read the dendropy tree as an ete3 tree with quoted node names.
        newick = dendro_tree.as_string(schema="newick", suppress_rooting=True)
        return Tree(newick, format=1, quoted_node_names=True)

    dendro_tree_arc = _parse(gtdb_format_names(str_arc_tree))
    max_length = _longest_edge(dendro_tree_arc, 0)

    dendro_tree_bac = _parse(gtdb_format_names(str_bac_tree))
    max_length = _longest_edge(dendro_tree_bac, max_length)

    ete_tree_arc = _to_ete(dendro_tree_arc)
    ete_tree_bac = _to_ete(dendro_tree_bac)

    merged = Tree(name='root')
    for subtree in (ete_tree_arc, ete_tree_bac):
        merged.add_child(subtree.get_tree_root(), dist=max_length)
    return merged
Example #9
0
def convert_history_to_simmap(tree_path, output_path):
    """Rewrite a history tree so that each kept node name holds its mapping.

    Every node name is expected to contain a ``{label}`` part.  For each
    non-root node that is not consumed as a "mapping" ancestor, the name is
    replaced by ``<leaf name>:{label,dist[:label,dist...]}``, where the extra
    ``label,dist`` pairs come from the chain of ancestors whose names contain
    "mapping".  Those mapping nodes are then deleted and the tree is written
    to ``output_path`` through ``getTreeStr``, followed by ";".

    :param tree_path: input for ``Tree(..., format=1)``.
    :param output_path: path of the file to write.
    """
    label_regex = re.compile("{(.*?)}")
    history = Tree(tree_path, format=1)
    # The root is pre-seeded as visited, so it is never rewritten below.
    visited_nodes = [history.get_tree_root()]
    for node in history.traverse("postorder"):
        if node in visited_nodes:
            continue
        label_match = label_regex.search(node.name)
        node_label = label_match.group(1)
        node_name = node.name.replace(label_match.group(0), "")
        # Only leaves keep their name in front of the mapping expression.
        node_expression = node_name if node.is_leaf() else ""
        node_expression = node_expression + ":{" + node_label + "," + str(
            node.dist)
        visited_nodes.append(node)
        try:
            # Collect the chain of "mapping" ancestors above this node.
            curNode = node.up
            while "mapping" in curNode.name:
                node_label = label_regex.search(curNode.name).group(1)
                node_expression = node_expression + ":" + node_label + "," + str(
                    curNode.dist)
                visited_nodes.append(curNode)
                curNode = curNode.up
        except AttributeError:
            # Best effort, as before: the walk stops when it runs past the
            # root (curNode is None) or meets a mapping node without a
            # "{label}" part.  Other errors are no longer swallowed.
            pass
        node_expression = node_expression + "}"
        node.name = node_expression
    # remove mapping nodes
    for node in history.traverse("postorder"):
        if "mapping" in node.name:
            node.delete()
    # write tree on your own because ete3 writer has a bug
    tree_str = getTreeStr(history.get_tree_root())
    with open(output_path, "w") as output_file:
        output_file.write(tree_str + ";")
Example #10
0
def rescale_tree_branch_lengths(tree, factor):
    """
    Multiply every branch length of a tree by a constant factor.

    :param tree: newick tree string or txt file containing one tree; any
                 non-string value is used as a tree object directly and is
                 modified in place
    :param factor: the factor by which to multiply all branch lengths in tree
    :return: a string of the scaled tree in Newick format (format=1,
             distances written with 10 decimals)
    """
    # isinstance also accepts str subclasses, unlike the former type() == str.
    if isinstance(tree, str):
        tree = Tree(get_newick_tree(tree), format=1)
    tree_root = tree.get_tree_root()
    # the root dist is 1.0, we don't want it: only descendants are scaled
    for node in tree_root.iter_descendants():
        node.dist = node.dist * factor
    return tree.write(format=1, dist_formatter="%.10f")
Example #11
0
def get_branch_lengths(tree):
    """
    Collect the branch lengths of all nodes below the root.

    :param tree: Tree node or tree file or newick tree string
    :return: list with the ``dist`` of every descendant of the tree root
             (the individual lengths, not their sum)
    """
    # isinstance also accepts str subclasses, unlike the former type() == str.
    if isinstance(tree, str):
        tree = Tree(get_newick_tree(tree), format=1)
    tree_root = tree.get_tree_root()
    # the root dist is 1.0, we don't want it: only descendants are listed
    return [node.dist for node in tree_root.iter_descendants()]
Example #12
0
def Check_Tree(tre_):
    """Read a Newick string or a Tree instance and return a Tree object
    with all the nodes having names.

    A string is parsed with ``Tree(tre_, format=3)``; any other value is used
    as the tree itself and is modified in place.  An unnamed root is called
    'diploid', every other unnamed node receives an id from
    ``gen_unique_ids``, and the ``dist`` of every node is reset to 0.

    :param tre_: Newick string or Tree instance.
    :return: the tree with every node named.
    """
    # check if the input is a Newick string
    return_tr = Tree(tre_, format=3) if isinstance(tre_, str) else tre_

    root = return_tr.get_tree_root()
    if not root.name:
        root.name = 'diploid'

    # Collected in traversal order, so ids are handed out in that order.
    nameless_nodes = [node for node in return_tr.traverse() if not node.name]
    new_ids = gen_unique_ids(len(nameless_nodes))

    for node in return_tr.traverse():
        node.dist = 0
    for id_idx, node in enumerate(nameless_nodes):
        node.name = new_ids[id_idx]
    return return_tr
Example #13
0
def reroot(history_path, tree_path):
    """Re-insert the root child that is missing from a history tree.

    The original tree is read from ``tree_path`` and the history from
    ``history_path``.  When the original root has more than two children the
    tree counts as unrooted and nothing is done.  Otherwise the root child
    whose branch length is closest to ``size(tree) - size(history)`` is taken
    as the missing node; the history root children whose lengths match that
    node's children are moved below a new node ``missing_node{label}``, where
    the label is 1 when at least two of the collected labels are 1, else 0.
    After the total sizes have been verified, the fixed history overwrites
    ``history_path``.

    Exits the process with status 1 when the sizes do not match.

    :param history_path: path of the history tree (read, then overwritten).
    :param tree_path: path of the original tree.
    """
    history = Tree(history_path, format=1)
    tree = Tree(tree_path)
    if len(tree.get_children()) > 2:
        print("original tree is not rooted, no need to fix history")
        return
    missing_length = size(tree) - size(history)
    tree_children = tree.get_children()
    # Ties go to the second child, as the comparison is strict.
    if abs(tree_children[0].dist-missing_length) < abs(tree_children[1].dist-missing_length):
        missing_child = tree_children[0]
    else:
        missing_child = tree_children[1]
    grandchildren_dists = [child.dist for child in missing_child.get_children()]
    children_to_detach = []
    labels = []
    # Iterate over a copy: detach() removes the child from the very list
    # returned by get_children(), which would make the loop skip siblings.
    for child in list(history.get_children()):
        if child.dist in grandchildren_dists:
            labels.append(0 if "{0}" in child.name else 1)
            children_to_detach.append(child.detach())
    remaining_child = history.get_children()[-1]
    labels.append(0 if "{0}" in remaining_child.name else 1)
    chosen_label = 0 if sum(labels) < 2 else 1
    new_child = history.get_tree_root().add_child(name="missing_node{" + str(chosen_label) + "}", dist=missing_length)
    for child in children_to_detach:
        new_child.add_child(child)
    # make sure that now the tree and the history are of the same length
    if abs(size(history)-size(tree)) > 0.00001:
        print("Error! failed to fix history tree")
        print("size(history) = ", size(history) , "\n size(tree) = ", size(tree))
        exit (1)
    # make sure that all the lengths were written to tree string
    history_str = history.write(outfile=None, format=1)
    new_history = Tree(history_str, format=1)
    if abs(size(new_history)-size(tree)) > 0.00001:
        print("Error! failed to fix history tree newick format")
        # Report the sizes that were actually compared (this line used to
        # reference an undefined name and crashed with NameError).
        print("size(new_history) = ", size(new_history) , "\n size(tree) = ", size(tree))
        exit (1)
    # if all went well, write the fixed history to its original file
    history.write(outfile=history_path, format=1)
Example #14
0
def create_base_tree(input_tree_path, output_path):
    """Strip the history annotations from a tree and write the plain tree.

    Nodes whose name contains "mapping" are removed; their branch length is
    added to their first child beforehand.  For all other nodes the markers
    "{0}" and "{1}" are dropped from the name.  The root is named "root" and
    the result is written to ``output_path`` in format 1.

    :param input_tree_path: input for ``Tree(..., format=1)``.
    :param output_path: path of the file to write.
    :return: ``output_path``.
    """
    tree = Tree(input_tree_path, format=1)
    for node in tree.traverse():
        if "mapping" not in node.name:
            node.name = node.name.replace("{0}", "").replace("{1}", "")
            continue
        # hand the length of the removed branch over to its single child
        only_child = node.get_children()[0]
        only_child.dist += node.dist
        node.delete()
    # name the root
    tree.get_tree_root().name = "root"
    tree.write(outfile=output_path, format=1)
    return output_path
Example #15
0
def build_Protonation_Tree(peptide, residue_states):
    """Enumerate the protonation combinations of a peptide with a tree.

    The tree is grown one level per element of ``peptide`` by
    ``populate_leaves``.  Every root-to-leaf path is then read back as one
    combination: each node on the path (root excluded) contributes
    ``<state>_<resid>.<chain>``, where ``resid`` and ``chain`` are the two
    parts of the node name split on ".".

    :param peptide: indexable sequence of residues; element ``level`` is
        passed to ``populate_leaves`` for tree level ``level``.
    :param residue_states: passed through to ``populate_leaves`` unchanged.
    :return: set of tuples of strings, each tuple sorted by chain and resid.
    """
    print("Building Protonation Trees from peptide %s" % list(peptide))
    Peptide_Tree = Tree()
    Root = Peptide_Tree.get_tree_root()
    Root.add_feature("name", "root")
    Root.add_feature("state", "delete")
    sys.stdout.write("Expanding tree from level ")
    for level in range(len(peptide)):
        sys.stdout.write(str(level) + " ")
        sys.stdout.flush()
        # The second return value of populate_leaves is not used here.
        Peptide_Tree, _ = populate_leaves(Peptide_Tree, peptide[level],
                                          residue_states)

    print("\nSaving protonations from Tree...")

    all_protonations_set = set()
    for leaf in Peptide_Tree.iter_leaves():
        # the leaf itself plus all of its ancestors except the root
        lineage = [leaf] + leaf.get_ancestors()[:-1]
        protonations = []
        for node in lineage:
            resid, chain = node.name.split(".")
            protonations.append((node.state, resid, chain))
        # sort by chain and resid to avoid permutations of the same combination
        protonations.sort(key=itemgetter(2, 1))
        all_protonations_set.add(
            tuple(["%s_%s.%s" % (t[0], t[1], t[2]) for t in protonations]))
        # No `del ancestor` here: that name is unbound whenever a leaf hangs
        # directly below the root, which used to raise UnboundLocalError.
    # Drop the (potentially large) tree before returning.
    del Peptide_Tree
    gc.collect()
    return all_protonations_set
Example #16
0
def plot_simulation(input_simu, args_output, args_prefs, args_beta):
    """Produce the plots for one simulation output.

    When ``args_prefs`` is not empty, the preference table is read (without
    its 'site' column), each row is raised to ``args_beta`` and normalised,
    and it is plotted against the frequencies from the matching ".ali" file.
    Then the tree in ``input_simu`` is read; for every node feature other
    than "dist" and "support" that is numeric on all nodes, a "Log<feature>"
    feature (log of value over root value) is added and plotted.  Features
    whose name contains "Branch" and either "dNd" or "LogNe" are passed to
    ``plot_correlation``.

    :param input_simu: path of the tree file (read with format=1).
    :param args_output: output directory used as prefix of every plot path.
    :param args_prefs: path of the preferences CSV, or "" to skip that part.
    :param args_beta: exponent applied to the preferences.
    """
    if args_prefs != "":
        prefs = pd.read_csv(args_prefs, sep=",").drop('site', axis=1).values
        for idx, row in enumerate(prefs):
            powered = np.power(row, args_beta)
            prefs[idx, :] = powered / np.sum(powered)
        freqs = np.array(aa_freq_from_ali(input_simu.replace(".nhx", ".ali")))

        prefs_diversity = [shanon_diversity(row) for row in prefs]
        freqs_diversity = [shanon_diversity(row) for row in freqs]
        plot_xy(prefs_diversity, freqs_diversity,
                "{0}/correlation.prefs.shannon".format(args_output), "entropy")

        plot_xy(prefs.flatten(), freqs.flatten(),
                "{0}/correlation.prefs".format(args_output), "frequency")

    t = Tree(input_simu, format=1)

    # every feature name that occurs on at least one node
    feature_names = set()
    for node in t.traverse():
        feature_names = feature_names.union(node.features)

    branch_dict = {}
    for arg in feature_names:
        if arg == "dist" or arg == "support":
            continue
        values = np.array([
            float(getattr(n, arg)) for n in t.traverse()
            if arg in n.features and convertible_to_float(getattr(n, arg))
        ])
        if len(values) <= 1:
            continue
        # only features that are numeric on every node get a tree plot
        if len(values) == len(list(t.traverse())):
            root_value = float(getattr(t.get_tree_root(), arg))
            log_name = "Log" + arg
            for n in t.traverse():
                ratio = float(getattr(n, arg)) / root_value
                n.add_feature(log_name, np.log(ratio))
            plot_tree(t, log_name,
                      "{0}/tree.{1}.pdf".format(args_output, arg))
        if ("Branch" in arg) and (("dNd" in arg) or ("LogNe" in arg)):
            branch_dict[arg] = values

    plot_correlation("{0}/correlation.Ne.dNdS.pdf".format(args_output),
                     branch_dict, {},
                     global_min_max=False)
def print_random_tree(num_nodes=5):
    """
    Populate a fresh tree with ``num_nodes`` and print a tour of its basic
    attributes (children, parent, name, dist, root lookups), followed by the
    name of every leaf.
    """

    tree = Tree()
    tree.populate(num_nodes)

    print("t", tree)
    print("children", tree.children)
    print("get_children", tree.get_children())
    print("up", tree.up)
    print("name", tree.name)
    print("dist", tree.dist)
    print("is_leaf", tree.is_leaf())
    print("get_tree_root", tree.get_tree_root())
    # Looked up only now, so the lines above are printed even when the tree
    # has no children.
    first_child = tree.children[0]
    print("children[0].get_tree_root", first_child.get_tree_root())
    print("children[0].children[0].get_tree_root",
          first_child.children[0].get_tree_root())
    # iterating a tree yields its leaves
    for leaf in tree:
        print(leaf.name)
Example #18
0
def parse_union_tree(history_1, history_2, base_tree_path, debug=False):
    """Merge two label histories of the same base tree into one united tree.

    The base tree is read from ``base_tree_path`` (its root is renamed
    "_baseInternal_30").  For every branch (parent, node) of the base tree,
    visited in pre-order, the intermediate nodes that each history places on
    that branch are interleaved by their distance from the parent and added
    to the united tree as "internal_<n>" nodes; the base node itself is then
    added with the remaining branch length.  Every node of the united tree
    carries the features ``history_1_label`` and ``history_2_label``.

    :param history_1: first history tree; its nodes are looked up by name and
        must expose a ``label`` attribute.
    :param history_2: second history tree, same requirements.
    :param base_tree_path: input for ``Tree(..., format=1)``.
    :param debug: when true, progress information is printed.
    :return: the united tree.
    """
    base_tree = Tree(base_tree_path, format=1)
    # add for debugging
    base_tree.get_tree_root().name = "_baseInternal_30"
    united_tree = Tree()
    united_tree.dist = 0  # initialize distance to 0
    united_tree.get_tree_root().name = history_1.get_tree_root(
    ).name  # set the name of the root
    united_tree.add_feature("history_1_label", history_1.get_tree_root().label)
    united_tree.add_feature("history_2_label", history_2.get_tree_root().label)
    union_nodes_number = 0  # running counter used to name the added internal nodes
    for original_node in base_tree.traverse(
            "preorder"
    ):  # traverse the tree in pre-order to assure that for any visited node, its parent from the base branch is already in the united tree
        original_parent = original_node.up
        if original_parent != None:  # will be none only in the case the original node is the root
            if debug:
                print("handled branch: (", original_node.name, ",",
                      original_parent.name, ")")
            curr_union_parent = united_tree.search_nodes(
                name=original_parent.name)[0]
            # locate the first node of history 1 on this branch; hist_1_done
            # stays True when the history has no intermediate node here
            hist_1_done = True
            hist_1_curr_child = None
            hist_1_parent = history_1.search_nodes(name=original_parent.name)[
                0]  # need to check names consistency across the 3 trees
            for child in hist_1_parent.children:
                if len(base_tree.search_nodes(name=child.name)) == 0 and len(
                        child.search_nodes(name=original_node.name)
                ) > 0:  # if the child is a root in a tree that holds the original child node, then this child must be on the branch of interest
                    hist_1_curr_child = child
                    hist_1_done = False
                    break
            if hist_1_done:
                hist_1_curr_child = history_1.search_nodes(
                    name=original_node.name)[0]
            hist_1_current_label = hist_1_curr_child.label

            # same lookup for history 2
            hist_2_done = True
            hist_2_curr_child = None
            hist_2_parent = history_2.search_nodes(name=original_parent.name)[
                0]  # need to check names consistency across the 3 trees
            for child in hist_2_parent.children:
                if len(base_tree.search_nodes(name=child.name)) == 0 and len(
                        child.search_nodes(name=original_node.name)
                ) > 0:  # if the child is a root in a tree that holds the original child node, then this child must be on the branch of interest
                    hist_2_curr_child = child
                    hist_2_done = False
                    break
            if hist_2_done:
                hist_2_curr_child = history_2.search_nodes(
                    name=original_node.name)[0]
            hist_2_current_label = hist_2_curr_child.label

            # interleave the intermediate nodes of both histories until
            # neither has one left on this branch
            while not hist_1_done or not hist_2_done:

                # a finished history keeps an infinite distance, so the other
                # one is always picked below
                hist_1_dist = float("inf")
                hist_2_dist = float("inf")
                if not hist_1_done:  # if there is a node closer to the original node in history 1 -> add it to the united tree first
                    hist_1_dist = hist_1_curr_child.get_distance(
                        original_parent.name) - curr_union_parent.get_distance(
                            original_parent.name)
                if not hist_2_done:
                    hist_2_dist = hist_2_curr_child.get_distance(
                        original_parent.name) - curr_union_parent.get_distance(
                            original_parent.name)

                if debug:
                    if not hist_1_done:
                        print("history 1 has current child of ",
                              original_parent.name, ": ",
                              hist_1_curr_child.name, " with label: ",
                              hist_1_current_label,
                              " and distance from parent is: ", hist_1_dist)
                    if not hist_2_done:
                        print("history 2 has current child of ",
                              original_parent.name, ": ",
                              hist_2_curr_child.name, " with label: ",
                              hist_2_current_label,
                              " and distance from parent is: ", hist_2_dist)

                # first, check if now the two current children have the same name, and if this name is in the base tree - exit
                if hist_1_curr_child.name == hist_2_curr_child.name and len(
                        base_tree.search_nodes(
                            name=hist_1_curr_child.name)) > 0:
                    break

                # else, at least one of the histories has more than one step to go before reaching the bottom of the branch
                if hist_1_dist < hist_2_dist:  # add the node from history 1 and travel down to the next node in history 1
                    if debug:
                        print(
                            "adding child from history 1 which precedes to the one from history 2"
                        )
                        print("the label of the added node in history 1 is: ",
                              hist_1_curr_child.label)
                        print(
                            "the label of the added node in histroy 2 remains like papa: ",
                            hist_2_current_label)
                    curr_union_parent = curr_union_parent.add_child(
                        child=None,
                        name="internal_" + str(union_nodes_number),
                        dist=hist_1_dist,
                        support=None)
                    curr_union_parent.add_feature("history_1_label",
                                                  hist_1_curr_child.label)
                    curr_union_parent.add_feature("history_2_label",
                                                  hist_2_current_label)
                    hist_1_parent = hist_1_curr_child
                    # a node with exactly one child is followed further down;
                    # anything else ends history 1 on this branch
                    if len(hist_1_parent.children) == 1:
                        hist_1_curr_child = hist_1_parent.children[0]
                    else:
                        hist_1_done = True
                    if debug:
                        print("united tree is now: \n", united_tree)
                        if hist_1_done:
                            print(
                                "history 1 on the handled branch is complete")
                        else:
                            print(
                                "history 1 on the handled branch isn't complete yet"
                            )

                else:  # add the node from history 2 and travel down to the next node in history 2
                    if debug:
                        print(
                            "adding child from history 2 which precedes to the one from history 1"
                        )
                        print("the label of the added node in history 2 is: ",
                              hist_2_curr_child.label)
                        print(
                            "the label of the added node in history 1 remains like papa: ",
                            hist_1_current_label)
                    curr_union_parent = curr_union_parent.add_child(
                        child=None,
                        name="internal_" + str(union_nodes_number),
                        dist=hist_2_dist)  # added as a new branch
                    curr_union_parent.add_feature("history_1_label",
                                                  hist_1_current_label)
                    curr_union_parent.add_feature("history_2_label",
                                                  hist_2_curr_child.label)
                    hist_2_parent = hist_2_curr_child
                    # same rule as for history 1
                    if len(hist_2_parent.children) == 1:
                        hist_2_curr_child = hist_2_parent.children[0]
                    else:
                        hist_2_done = True
                    if debug:
                        print("united tree is now: \n", united_tree)
                        if hist_2_done:
                            print(
                                "history 2 on the handled branch is complete")
                        else:
                            print(
                                "history 2 on the handled branch isn't complete yet"
                            )
                union_nodes_number += 1

            # now add the original node as the child of the current parent
            original_dist = original_node.dist
            # what is left of the base branch after the internal nodes added above
            residual = original_dist - curr_union_parent.get_distance(
                united_tree.search_nodes(name=original_parent.name)[0])
            curr_union_parent = curr_union_parent.add_child(
                child=None, name=original_node.name, dist=residual)
            curr_union_parent.add_feature(
                "history_1_label",
                history_1.search_nodes(name=original_node.name)[0].label)
            curr_union_parent.add_feature(
                "history_2_label",
                history_2.search_nodes(name=original_node.name)[0].label)

    return united_tree
                        "--compounds",
                        help="metabolites file one per line (uncoded)",
                        required=True)

    args = parser.parse_args()

    tree_file = args.tree
    compounds_file = args.compounds
    outfile = args.json

    with open(compounds_file, 'r') as f:
        compounds = [i.strip("\n") for i in f.readlines()]

    ctree = Tree(tree_file, format=8)
    # p.get_tree_root().name = "Chemicals"
    ctree.get_tree_root().name = "Compounds"

    compounds_parent = {}
    pb_compounds = []
    no_ancestor = []

    for elem in compounds:
        try:
            paths = [[a.name for a in i.get_ancestors()]
                     for i in ctree.search_nodes(name=elem)
                     ]  # GLC Glucopyranose
            if all(i == ['Compounds'] for i in paths):
                no_ancestor.append(elem)
            else:
                compounds_parent[elem] = [
                    list(reversed(path))[1] for path in paths
Example #20
0
from ete3 import Tree
# Tour of the basic node attributes on a random tree.  The parenthesised
# single-argument print calls are valid on Python 2 and Python 3 alike.
t = Tree()
# We create a random tree topology
t.populate(15)
print(t)
print(t.children)
print(t.get_children())
print(t.up)
print(t.name)
print(t.dist)
print(t.is_leaf())
print(t.get_tree_root())
print(t.children[0].get_tree_root())
print(t.children[0].children[0].get_tree_root())
# You can also iterate over tree leaves using a simple syntax
for leaf in t:
    print(leaf.name)
Example #21
0
class MCTS:

    def __init__(self, game, nnet, args):
        """
        Store the collaborators and start with empty search statistics.

        game: environment object, kept as ``self.env``
        nnet: network object, kept as ``self.nnet``
        args: configuration object, kept as ``self.args``
        """
        # Open question kept from the original author:
        # remove env? Need to re-pass it in every move - not just the first move
        self.env, self.nnet, self.args = game, nnet, args

        # Search statistics, all keyed by state s or by edge (s, a)
        self.Qsa = dict()  # Q values for s,a (as defined in the paper)
        self.Nsa = dict()  # number of times edge s,a was visited
        self.Ns = dict()   # number of times board s was visited
        self.Ps = dict()   # initial policy (returned by neural net)
        self.Usa = dict()  # last upper-confidence value computed for s,a

        # for plotting a tree diagram
        self.tree = Tree()

    def get_action_prob(self, state_2d, root_state, agent, temp=1):
        """
        This function performs numMCTSSims simulations of MCTS starting from
        canonicalBoard.
        Returns:
            probs: a policy vector where the probability of the ith action is
                   proportional to Nsa[(s,a)]**(1./temp)
        """
        # no point in doing mcts if its random anyways
        for i in range(self.args.numMCTSSims):
            self.env.reset(root_state)
            self.search(state_2d, agent, done=False)

        self.env.reset(root_state)
        s = self.env.get_mcts_state(root_state, g_accuracy)
        counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.env.get_action_size(agent))]

        if temp == 0:
            bestA = np.argmax(counts)
            probs = [0] * len(counts)
            probs[bestA] = 1
            return probs

        counts = [x ** (1.0 / temp) for x in counts]
        probs = [x / float(sum(counts)) for x in counts]
        return probs

    def search(self, state_2d, agent, done):
        """
        This function performs one iteration of MCTS. It is recursively called
        till a leaf node is found. The action chosen at each node is one that
        has the maximum upper confidence bound as in the paper.
        Once a leaf node is found, the neural network is called to return an
        initial policy P and a value v for the state. This value is propagated
        up the search path. In case the leaf node is a terminal state, the
        outcome is propagated up the search path. The values of Ns, Nsa, Qsa are
        updated.
        NOTE: the return values are the negative of the value of the current
        state. This is done since v is in [-1,1] and if v is the value of a
        state for the current agent, then its value is -v for the other agent.
        Returns:
            v: the negative of the value of the current canonicalBoard

        *** remember to convert v to float if bringing in a new definition
        """
        agent %= 2
        state = self.env.get_state()
        s = self.env.get_mcts_state(state, g_accuracy)
        # ---------------- TERMINAL STATE ---------------
        if done:
            if self.env.mcts_steps >= self.env.steps_till_done:
                _, v = self.nnet.predict(state_2d, agent)
                # print("done and got to the end at step: ", curr_env.steps, " and value: ", v)
                return v
            # print("Done, at step: ", curr_env.steps, " returning: ", curr_env.terminal_cost)
            # return value as if fallen over
            t = float(self.env.terminal_cost) # * ((self.env.steps_till_done - self.env.mcts_steps)/self.env.steps_till_done + 1))
            return t

        # ------------- EXPLORING FROM A LEAF NODE ----------------------
        # check if the state has been seen before. If not then assign Ps[s]
        # a probability for each action, eg Ps[s1] = [0.25, 0.75] for a = [0(left) 1(right)]
        # Note, we do not take an action here. Just get an initial policy
        # Also get the state value - work this out later
        if s not in self.Ps:
            pi, v = self.nnet.predict(state_2d, agent)
            self.Ps[s] = pi   # list
            self.Ns[s] = 0
            return v

        # ------------- GET agent AND ADVERSARY ACTIONS -----------------------------
        # search through the valid actions and update the UCB for all actions then update best actions
        # pick the action with the highest upper confidence bound
        cur_best = -float('inf')  # set current best ucb to -inf
        best_act = None  # null action
        for a in range(self.env.get_action_size(agent)):
            if (s, a) in self.Qsa:
                q = self.Qsa[(s, a)] if agent == 0 else -self.Qsa[(s, a)]
                u = q + self.args.cpuct * self.Ps[s][a]*np.sqrt(self.Ns[s])/(1+self.Nsa[(s, a)])
            else:
                u = self.args.cpuct * self.Ps[s][a]*np.sqrt(self.Ns[s] + EPS)
            self.Usa[(s, a)] = u
            if u > cur_best:
                cur_best = u
                best_act = a
        a = best_act
        # ----------- RECURSION TO NEXT STATE ------------------------
        next_state, loss, next_done, _ = self.env.step(a, agent)      # not a true step -> will update mcts_steps
        if self.args.mctsTree:
            self.add_tree_node(state, next_state, a, agent)
        next_state_2d = self.env.get_state_2d(prev_state_2d=state_2d)
        v = self.search(next_state_2d, agent+1, next_done)

        # ------------ BACKUP Q-VALUES AND N-VISITED -----------------
        # after we reach the terminal condition then the stack unwinds and we
        # propagate up the tree backing up Q and N as we go
        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1

        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return v

    def add_tree_node(self, prev_state, curr_state, a, agent):
        """
        Record the transition prev_state --a--> curr_state in the plotting tree.

        Each node of the tree carries
            "name, s_t" - the mcts_state e.g. (699923423, 124134235, 234234124, 23524634634)
            "agent, p_t" - the agent at time t
            "action, a_(t-1)" - the action needed to get from prev_state to the curr_state (the previous agents action)
        Note, normally a node is stored as (s, a) = (curr_state, action from curr_state)...
        However, this would mean the tree has to be 2x as big. Instead use (curr_state, action to curr_state).

        Args:
            prev_state: environment state the action was taken from.
            curr_state: environment state reached by the action.
            a: the action that was taken.
            agent: the agent that took the action; the child node is tagged
                with ``(agent + 1) % 2``.

        Raises:
            IndexError: if the child state is already in the tree but the
                parent state is not (unchanged from the previous behaviour).
        """
        parent_mcts_state = self.env.get_mcts_state(prev_state, g_accuracy)
        child_mcts_state = self.env.get_mcts_state(curr_state, g_accuracy)

        # Each search_nodes call walks the tree, so look each name up once and
        # reuse the result (the tree is not modified between the lookups and
        # their use below).
        parent_matches = self.tree.search_nodes(name=parent_mcts_state)
        child_matches = self.tree.search_nodes(name=child_mcts_state)

        # if neither the parent or the child node exist then this must be a root
        if not child_matches and not parent_matches:
            parent = self.tree.get_tree_root()
            parent.name = parent_mcts_state
            parent.dist = 5
            child = parent.add_child(name=child_mcts_state, dist=5)
            # the root node doesn't have an action that got to it
            parent.add_features(action=None, agent=agent)
            # the action is taken from prev->curr, so the child is tagged with
            # the agent that moves next
            child.add_features(action=a, agent=(agent+1) % 2)
            return

        parent = parent_matches[0]
        # no node for the child state yet: hang it below the parent
        if not child_matches:
            child = parent.add_child(name=child_mcts_state, dist=5)
            child.add_features(action=a, agent=(agent+1) % 2)

        # if child and parent already exist then this must be another
        # simulation -> dont need to add another node.
        # if we have just taken the next true step, then change the node style
        if self.env.mcts_steps <= self.env.steps+1:
            nstyle = NodeStyle()
            nstyle["vt_line_color"] = "#ff0000"
            nstyle["vt_line_width"] = 8
            nstyle["vt_line_type"] = 0  # 0 solid, 1 dashed, 2 dotted
            nstyle["shape"] = "sphere"
            nstyle["size"] = 20
            nstyle["fgcolor"] = "darkred"
            parent.add_features(steps=self.env.steps)
            parent.set_style(nstyle)

    def update_tree_values(self):
        """
        Annotate the tree created in add_tree_node with values of Nsa, Ns, Ps, Qsa and Usa and colour each
        annotation according to its value.

        Nodes on the tree are represented as (s_t, a_(t-1), p_t). Quantities that involve an action (Qsa, Usa,
        Nsa) are therefore looked up with the parent's state and the action stored on the node, whereas Ns is
        looked up with the node's own state.

        Faces previously attached to a non-root node are removed first so repeated calls do not stack them.
        """
        def to_hex(rgba):
            # colormap output holds floats in [0, 1]; ete3 wants "#rrggbb"
            return "#{0:02x}{1:02x}{2:02x}".format(
                *[int(round(channel * 255)) for channel in (rgba[0], rgba[1], rgba[2])])

        tree_itr = self.tree.traverse()
        # the first node yielded by the traversal is annotated as the root;
        # the loop below continues with the remaining nodes
        root = next(tree_itr)
        root_face = TextFace(u'(s\u209C, a\u209C\u208B\u2081, agent\u209C) = ({}, {}, {}, Ns={})'.format(
            [float(dim) / g_accuracy for dim in root.name],
            root.action, root.agent, self.Ns[root.name]))
        root_face.background.color = '#FF0000'
        root.add_face(root_face, column=0, position="branch-top")

        for node in tree_itr:
            # (s, a) is the edge that leads INTO this node
            s, a, agent = node.up.name, node.action, node.up.agent
            p_from_a = node.agent
            state = [float(dim) / g_accuracy for dim in node.name]
            try:
                delattr(node, "_faces")     # need to remove previously added faces otherwise they stack
            except AttributeError:
                # no faces were attached to this node yet
                pass
            _, v = self.nnet.predict(np.array(state), p_from_a)

            # -------- ADD ANNOTATION FOR STATE LOSS AND ACTION TAKEN ----------
            loss = self.env.state_loss(state=state)
            # unicode to get subscripts
            loss_face = TextFace(u'(x\u209C, u\u209C\u208B\u2081, agent\u209C) = ({0}, {1}, {2}),  c(x\u209C) = {3:.3f}, v_pred = {4:.3f}'.format(self.env.round_state(state=state), a, p_from_a, loss, v))
            # NOTE(review): the 255 offset presumably assumes loss lies in [-1, 0] - confirm
            c_loss = to_hex(cm.viridis(255+int(loss*255)))  # viridis goes from 0-255
            loss_face.background.color = c_loss   # need rgb colour in hex, "#FFFFFF"=(255, 255, 255)
            node.add_face(loss_face, column=0, position="branch-top")

            # -------- ADD ANNOTATION FOR ACTION VALUE WRT agent, Q --------
            if (s, a) in self.Qsa:
                q = self.Qsa[(s, a)] if agent == 0 else -self.Qsa[(s, a)]
                ucb = self.args.cpuct * self.Ps[s][a]*np.sqrt(self.Ns[s])/(1+self.Nsa[(s, a)])
                u = self.Usa[(s, a)]
            else:
                q = 0
                # NOTE(review): this branch reads Ps[node.name] while the branch
                # above (and search) read Ps[s] - confirm which is intended
                ucb = self.args.cpuct * self.Ps[node.name][a]*np.sqrt(self.Ns[s] + EPS)
                u = ucb
            q_formula = '(Q(x))' if agent == 0 else '(-Q(x))'
            QA_face = TextFace("U\u209C = {:.3f}{} + {:.3f}(ucb) = {:.3f}".format(q, q_formula, ucb, u))

            c_value = to_hex(cm.viridis(255+int((q+ucb)*255)))  # viridis goes from 0-255
            QA_face.background.color = c_value
            node.add_face(QA_face, column=0, position="branch-bottom")

            # -------- ADD ANNOTATION FOR NUMBER OF VISITS -------
            # have to use node.name for printing Ns
            ns = self.Ns.get(node.name, 0)
            # an edge that was never backed up has no Nsa entry; show it as 0
            # visits instead of failing with KeyError
            nsa = self.Nsa.get((s, a), 0)
            N_face = TextFace(" Nsa(x\u209C\u208B\u2081, u\u209C)={}, Ns(x)={}".format(nsa, ns))

            c_vis = to_hex(cm.YlGn(int(255*(1-nsa/(self.args.numMCTSSims*2)))))  # YlGn goes from 0-255
            N_face.background.color = c_vis
            node.add_face(N_face, column=1, position="branch-bottom")

    def show_tree(self):
        """Refresh the node annotations and open the tree viewer."""
        # the annotations only need to be current at the moment of display
        self.update_tree_values()

        # each node contains 3 attributes: node.dist, node.name, node.support;
        # the default leaf-name and branch-support rendering is switched off
        tree_style = TreeStyle()
        tree_style.show_branch_support = False
        tree_style.show_leaf_name = False
        self.tree.show(tree_style=tree_style)
Example #22
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to root a tree given an outgroup"""

import argparse
import os
import sys

from ete3 import Tree

parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tree')
parser.add_argument('-o', '--outgroup')
opts = parser.parse_args(sys.argv[1:])

# check arguments
# argparse stores None for an option that was not given, so a missing --tree
# has to be caught before os.path.isfile is asked about it
if opts.tree is None or not os.path.isfile(opts.tree):
    sys.stderr.write("File {0} not found".format(opts.tree))
    sys.exit(1)

# opts.outgroup always exists as an attribute (None when omitted); reading it
# can never raise NameError, so test the value itself
if opts.outgroup is None:
    sys.stderr.write("Outgroup must be defined (--outgroup)")
    sys.exit(1)

# root the tree on the requested outgroup and print it as newick (format 5)
t = Tree(opts.tree)
t.set_outgroup(opts.outgroup)
print(t.get_tree_root().write(format=5))
#!/homes/carlac/anaconda_ete/bin/python

# Copyright [1999-2015] Wellcome Trust Sanger Institute and the EMBL-European Bioinformatics Institute
# Copyright [2016-2019] EMBL-European Bioinformatics Institute
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys, os
from ete3 import Tree

# Read the tree file named by the first command-line argument, unroot the
# tree and print it in ete3's default newick format.
infile = sys.argv[1]
if not os.path.isfile(infile):
    # stream.write() takes a single string, so the file name has to be
    # interpolated before the call
    sys.stderr.write("File %s not found" % infile)
    sys.exit(1)

t = Tree(infile)
root = t.get_tree_root()
root.unroot()
print(root.write())
Example #24
0
		year_dict[acc_str]=year_str
	no_n=re.search(r'_N\d+', line)
	if no_n:
		no_n_str= no_n.group()	
		no_n_str=re.sub('_N','',no_n_str)
		no_N_dict[acc_str]=no_n_str
	
log_file.write("The number of clusters are:" + str(cluster_cnt))
log_file.close()
cdhit_file.close()

#print "Tree from FastTree program is being used to calculate root to leaf distances..."
#Passing in the tree generated by FastTree
FastTree=Tree(args.input2)
#Getting the root of the tree
root=FastTree.get_tree_root()
#Loop through each leaf of the tree
for leaf in FastTree:
	#Convert 'leaf' to string to allow manipulation
	leaf_str=str(leaf)
	#Accession followed by "_" and a four digit suffix
	#NOTE(review): the "." in the pattern is unescaped and matches any character - confirm a literal dot is not required
	acc_nu=re.search(r'\w{2}\d+.\d{1}_\d{4}|\w{2}_\d+.\d{1}_\d{4}',leaf_str)
	acc_nu=str(acc_nu.group())
	#Drop the trailing "_dddd" so the key matches the other dictionaries
	acc_nu=re.sub(r'_\d{4}$','',acc_nu)
	rt_lf=FastTree.get_distance(root,leaf)
	#Make a dictionary using acc_nu as key 
	branlength_dict[acc_nu]=rt_lf
	#Using the generated dictionaries to print the relevant information to a tab delimited file
	tsv_file.write(acc_nu + "\t" + year_dict[acc_nu] + "\t" + str(rt_lf) + "\t" + clust_dict[acc_nu] + "\t" + no_N_dict[acc_nu] + "\n")
#close() must actually be called; a bare "tsv_file.close" only names the method and leaves the file open
tsv_file.close()

        d = tree.get_distance(node, leaf)
        dists[leaf] = d
    sorted_dists = sorted(dists.items(), key=itemgetter(1))
    middle_node = sorted_dists[len(sorted_dists)/2][0]
    return middle_node

tree = Tree("((A:0.1,B:0.2):0.3,(C:0.5,D:0.1):0.05);") #or read in from a file
print tree
mean_root_to_tip = get_mean_root_to_tip(tree)

#divide mean distance into some number of contours
num_contours = 4
contours = []
for i in range(num_contours):
    print i+2
    contours.append(mean_root_to_tip/float(i+2))
print contours

#for each contour, print num of nodes for which one descendant will be picked
root = tree.get_tree_root()
for c in contours:
    to_keep = []
    for node in tree.traverse():
        if contour_node(root, node, c):
            node_to_keep = pick_average_tip(node)
            to_keep.append(node_to_keep)
    print "Contour at " + str(c) + ", " + str(len(to_keep)) + " in total."
    for taxon in to_keep:
        print taxon.name

Example #26
0
#!/homes/carlac/anaconda_ete/bin/python

# Copyright [1999-2015] Wellcome Trust Sanger Institute and the EMBL-European Bioinformatics Institute
# Copyright [2016-2019] EMBL-European Bioinformatics Institute
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#      http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys, os
from ete3 import Tree

# Read the tree file named by the first command-line argument, unroot the
# tree and print it as newick (format 5).
infile = sys.argv[1]
if not os.path.isfile(infile):
	# stream.write() takes a single string, so the file name has to be
	# interpolated before the call
	sys.stderr.write("File %s not found" % infile)
	sys.exit(1)

t = Tree(infile)
root = t.get_tree_root()
root.unroot()
print(root.write(format=5))
Example #27
0
def main(arg1, arg2):
    """
    Compute a split-based distance between the two trees built from
    ``arg1`` and ``arg2``.

    Both trees are re-rooted on the same outgroup, the nodes of the first
    tree are numbered in postorder, and every internal node is labelled with
    the interval "[min:max]" of the numbers found below it. A node of the
    second tree counts as shared when its own interval label is also found
    in the first tree.

    :param arg1: passed to ``Tree()`` to build the first tree
    :param arg2: passed to ``Tree()`` to build the second tree
    :return: Num_splits1 + Num_splits2 - 2 * Num_shared

    Side effect: sets the module-level global ``leaf_num`` to the number of
    leaves of the (re-rooted, second-child) subtree of the second tree.
    Raises ValueError if an "int" node does not have exactly two children
    (the two-name unpacking of ``get_children()`` fails).
    """

    # NOTE(review): this empty tree is overwritten by the unpacking below
    # before it is ever read
    t1 = Tree()
    tree1 = Tree(arg1)

    tree2 = Tree(arg2)

    # presumably a node/leaf name taken from tree1 that also exists in tree2;
    # getRandomNode is defined elsewhere - verify
    node_midpoint = getRandomNode(tree1)

    tree1.set_outgroup(node_midpoint)

    tree2.set_outgroup(node_midpoint)

    # Work on the second child of each root from here on; the first child is
    # bound to t1 and not used again.
    t1, tree2 = tree2.get_tree_root().children
    t1, tree1 = tree1.get_tree_root().children
    count = 0
    # Number every node of tree1 that has a non-blank name (postorder) and
    # copy the same number onto the node with that name in tree2.
    for leaf in tree1.traverse("postorder"):
        if (leaf.name.strip()):
            count += 1
            leaf.add_features(order=count)
            CurrentNode2 = tree2 & leaf.name
            CurrentNode2.add_features(order=count)

        # blank-named nodes are tagged "int"
        elif (leaf.name != node_midpoint):
            leaf.name = "int"

    # Same tagging for the nodes of tree2 whose name is the empty string.
    for node in tree2.traverse("postorder"):
        if (node.name == ""):
            node.name = "int"

    Num_splits1 = 0
    Num_splits2 = 0
    Num_shared = 0
    # Label each "int" node of tree1 with the [min:max] order interval of the
    # leaves below it.
    for node in tree1.traverse("postorder"):

        if ((node.name == "int")):
            Num_splits1 += 1
            cmin = float("+inf")
            cmax = 0
            d1, d2 = node.get_children()
            # Temporary tree used only to iterate the leaves below d1 and d2.
            # NOTE(review): add_child presumably re-points d1/d2's parent to
            # this temporary tree - confirm against ete3
            subtree = Tree()
            subtree.add_child(d1)
            subtree.add_child(d2)

            for leaf in subtree:

                if ((leaf.name != "int")):
                    CurrentNode2 = tree1 & leaf.name
                    cmin = min(CurrentNode2.order, cmin)
                    cmax = max(CurrentNode2.order, cmax)

            # a root node is counted above but keeps its "int" name
            if ((node.is_root() == False)):
                node.name = "[" + str(cmin) + ":" + str(cmax) + "]"

    # Same interval computation for tree2, then look the label up in tree1.
    for node in tree2.traverse("postorder"):

        if ((node.name == "int") and (node.is_root() == False)):
            Num_splits2 += 1
            cmin = float("+inf")
            cmax = 0
            size = 0
            d1, d2 = node.get_children()
            subtree2 = Tree()
            subtree2.add_child(d1)
            subtree2.add_child(d2)

            for leaf in subtree2:
                # size counts every leaf visited, including skipped ones
                size += 1
                if ((leaf.name != "int") and (leaf.name != node_midpoint)):
                    CurrentNode2 = tree2 & leaf.name
                    cmin = min(CurrentNode2.order, cmin)
                    cmax = max(CurrentNode2.order, cmax)
            # only a contiguous block of order numbers gets an interval label
            if (size == (cmax - cmin + 1)):
                node.name = "[" + str(cmin) + ":" + str(cmax) + "]"
            # NOTE(review): when the node was not renamed its name is still
            # "int", so this also matches any tree1 node still named "int" -
            # confirm that is intended
            if (tree1.search_nodes(name=node.name)):
                Num_shared += 1
    global leaf_num
    leaf_num = len(tree2.get_leaves())

    rf_dist = Num_splits1 + Num_splits2 - (2 * Num_shared)

    return rf_dist
class TreeHolder:
    def __init__(self,
                 tree,
                 logger,
                 scale=None,
                 labels_dict=None,
                 node_colors=defaultdict(lambda: 'black')):
        """
        Load the tree, root it if needed and attach a text label to each leaf.

        tree: passed to ``Tree()``
        logger: object with ``info`` and ``error`` methods
        scale: stored as ``self.scale`` for later drawing
        labels_dict: optional mapping leaf name -> label; when it is falsy the
            leaf name itself is used as label
        node_colors: mapping leaf name -> text colour

        Raises KeyError when a lookup for a leaf name fails while its label
        face is being built.
        """
        self.tree = Tree(tree)
        self.scale = scale

        # Only the first node produced by the traversal is inspected: when it
        # has three children the tree is rooted on the first of them.
        first_node = next(iter(self.tree.traverse()), None)
        if first_node is not None and len(first_node.children) == 3:
            logger.info("Trying to root tree by first child of root")
            logger.info(f'Children of root: {first_node.children}')
            self.tree.set_outgroup(first_node.children[0])

        for node in self.tree.traverse():
            # Hide node circles
            node.img_style['size'] = 0

            if not node.is_leaf():
                continue

            try:
                label = labels_dict[node.name] if labels_dict else node.name
                color = node_colors[node.name]
                name_face = TextFace(label, fgcolor=color)
            except KeyError:
                msg = f'There is not label for leaf {node.name} in labels file'
                logger.error(msg)
                raise KeyError(msg)
            node.add_face(name_face, column=0)

    def draw(self,
             file,
             colors,
             color_internal_nodes=True,
             legend_labels=(),
             show_branch_support=True,
             show_scale=True,
             legend_scale=1,
             mode="c"):
        """
        Render the tree to ``file`` with node backgrounds taken from ``colors``.

        Every node is expected to carry a ``color`` index; indices beyond the
        end of ``colors`` are clamped to the last entry. Internal nodes are
        only coloured when ``color_internal_nodes`` is true. The legend pairs
        ``legend_labels`` with the colours up to the largest index present in
        the tree.
        """
        last_index = len(colors) - 1

        for node in self.tree.traverse():
            if color_internal_nodes or node.is_leaf():
                node.img_style['bgcolor'] = colors[min(node.color, last_index)]

        style = TreeStyle()
        style.mode = mode
        style.scale = self.scale
        # Disable the default tip names config
        style.show_leaf_name = False
        style.show_branch_support = show_branch_support
        style.show_scale = show_scale

        # colours actually used by the tree: indices 0 .. highest node.color
        highest_used = max(v.color for v in self.tree.traverse())
        used_colors = colors[:highest_used + 1]

        # one legend row per (label, colour) pair; the white circles in
        # columns 1 and 3 are placed between/after the marker and the text
        for label, legend_color in zip(legend_labels, used_colors):
            marker = CircleFace(24 * legend_scale, legend_color)
            style.legend.add_face(marker, column=0)
            style.legend.add_face(CircleFace(13 * legend_scale, 'White'),
                                  column=1)
            style.legend.add_face(TextFace(label, fsize=53 * legend_scale),
                                  column=2)
            style.legend.add_face(CircleFace(13 * legend_scale, 'White'),
                                  column=3)

        self.tree.render(file, w=1000, tree_style=style)

    def get_all_leafs(self):
        return {node.name for node in self.tree.get_leaves()}

    def count_innovations_fitch(self, leaf_colors):
        def assign_colorset_feature(v):
            if v.is_leaf():
                v.add_features(colorset={leaf_colors[v.name]},
                               color=leaf_colors[v.name])
            else:
                try:
                    child1, child2 = v.children
                except ValueError:
                    print(v.children)
                    raise ValueError('Tree must me binary')
                cs1 = assign_colorset_feature(child1)
                cs2 = assign_colorset_feature(child2)
                v.add_features(
                    colorset=(cs1 & cs2) if len(cs1 & cs2) > 0 else cs1 | cs2)

            return v.colorset

        def chose_color(colorset):
            return sorted(colorset,
                          key=lambda c: color_counter[c],
                          reverse=True)[0]

        def down_to_leaves(v, color):
            if v.is_leaf(): return
            v.add_features(color=color if color in
                           v.colorset else chose_color(v.colorset))
            for child in v.children:
                down_to_leaves(child, v.color)

        def count_innovations(v, innovations):
            for child in v.children:
                if v.color != child.color:
                    innovations[child.color].append(child)
                count_innovations(child, innovations)

        color_counter = Counter(leaf_colors.values())

        # get colorsets for internal nodes
        root = self.tree.get_tree_root()
        assign_colorset_feature(root)

        # get color for internal nodes
        root_color = chose_color(root.colorset)
        down_to_leaves(root, root_color)

        # get inconsistent colors
        self.innovations = defaultdict(list)
        count_innovations(root, self.innovations)

    def count_parallel_rearrangements(self, skip_grey):
        score, count, count_all = 0, 0, 0
        for color, nodes in self.innovations.items():
            if len(nodes) <= 1 or (skip_grey and color == 1): continue
            count += 1
            count_all += len(nodes)
            for n1, n2 in combinations(nodes, 2):
                score += n1.get_distance(n2)
        return score, count, count_all

    def count_parallel_breakpoints(self):
        count = sum(map(len, self.innovations.values()))
        score = sum(
            n1.get_distance(n2)
            for n1, n2 in combinations((n for ns in self.innovations.values()
                                        for n in ns), 2))
        return score, count

    def draw_coloring(self, file):
        """
        Render the tree to ``file`` using each node's ``color`` as background.

        NOTE(review): ``self.colors`` is read here but is not assigned by any
        method visible in this class - confirm where it is set.
        """
        for current in self.tree.traverse():
            background = self.colors[current.color]
            current.img_style['bgcolor'] = background
        plain_style = TreeStyle()
        plain_style.show_leaf_name = False
        self.tree.render(file, w=1000, tree_style=plain_style)

    def prune(self, ls):
        self.tree.prune(list(ls))
                        "--pathways",
                        help="pathways file one per line (uncoded)",
                        required=True)

    args = parser.parse_args()

    tree_file = args.tree
    pathways_file = args.pathways
    outfile = args.json

    with open(pathways_file, 'r') as f:
        pathways = [i.strip("\n") for i in f.readlines()]

    ptree = Tree(tree_file, format=8)
    # p.get_tree_root().name = "Chemicals"
    ptree.get_tree_root().name = "Pathways"

    pathways_parent = {}
    pb_pathways = []
    no_ancestor = []

    for elem in pathways:
        try:
            paths = [[a.name for a in i.get_ancestors()]
                     for i in ptree.search_nodes(name=elem)
                     ]  # GLC Glucopyranose
            if all(i == ['Pathways'] for i in paths):
                no_ancestor.append(elem)
            else:
                pathways_parent[elem] = [
                    list(reversed(path))[2] for path in paths
Example #30
0
def _mark_overlap_ancestors(tree, leaf_name):
    """Add leaf_name to the ``descen`` list of the leaf and of every ancestor below the root."""
    node = tree.search_nodes(name=leaf_name)[0]
    while node.up:
        if hasattr(node, "descen"):
            node.descen.append(leaf_name)
        else:
            node.add_features(descen=[leaf_name])
        node = node.up


def _graft_unshared_subtrees(source_tree, merged):
    """Attach every subtree of source_tree that holds no shared leaf to merged; return the merged tree."""
    for node in source_tree.traverse("postorder"):
        if not hasattr(node, "descen"):
            continue

        for c in node.get_children():
            # children carrying "descen" contain shared leaves and are
            # already represented in the merged tree
            if hasattr(c, "descen"):
                continue

            if len(node.descen) == 1:
                # anchor is a single shared leaf
                new_node = merged.search_nodes(name=node.descen[0])[0]
            else:
                # anchor is the ancestor of all shared leaves below node
                new_node = merged.get_common_ancestor(set(node.descen))

            if hasattr(new_node.up, "new_child"):
                # a junction node was inserted above the anchor earlier:
                # add the subtree as another child of it
                new_node.up.add_child(c)
            elif len(node.descen) == 1 or new_node.up:
                # insert a fresh junction node next to the anchor, holding
                # the subtree and a copy of the anchor, then drop the
                # original anchor
                node_change = new_node.up.add_child()
                node_change.add_child(c)
                node_add_old = new_node.copy()
                node_change.add_features(new_child=True)
                node_change.add_child(node_add_old)
                new_node.detach()
            else:
                # the anchor is the root of the merged tree
                new_node = merged.get_tree_root()
                if hasattr(new_node, "new_root"):
                    new_node.add_child(c)
                else:
                    # put a new root above the old one, flagged so that
                    # later subtrees are added to it directly
                    root_copy = new_node.copy()
                    new_tree = Tree()
                    new_root = new_tree.get_tree_root()
                    new_root.add_features(new_root=True)
                    new_root.add_child(c)
                    new_root.add_child(root_copy)
                    merged = new_tree
    return merged


def scm(tree1, tree2):
    """
    Merge two trees that share leaves and return the result as a newick string.

    The leaves common to both trees form the backbone: a copy of tree1 is
    pruned to those leaves, the nodes reported by ``rf_dist_list.main`` are
    deleted from it, and then every subtree of tree1 and of tree2 that
    contains no shared leaf is attached to the backbone next to the shared
    leaves it sat beside in its original tree.

    :param tree1: ete3 tree; annotated in place with ``descen`` features, and
        its unshared subtrees are handed to the merged tree
    :param tree2: ete3 tree; treated the same way as tree1
    :return: newick string written with format=9. When fewer than three
        leaves are shared, tree2 is returned unchanged in that format.

    NOTE(review): ``intersection`` and ``rf_dist_list.main`` are defined
    elsewhere; the latter presumably returns pairs of leaf-name collections
    describing splits on which the two pruned trees disagree - confirm.
    """
    leaf_list1 = [leaf.name for leaf in tree1]
    leaf_list2 = [leaf.name for leaf in tree2]
    overlap = intersection(leaf_list1, leaf_list2)
    if len(overlap) < 3:
        return tree2.write(format=9)

    tree1_copy = tree1.copy()
    tree2_copy = tree2.copy()
    tree1_copy.prune(overlap)
    tree2_copy.prune(overlap)

    # The result of this first call is not used; the call itself is kept in
    # case rf_dist_list.main has side effects callers rely on.
    rf_dist_list.main(tree2_copy.copy(), tree1_copy.copy())
    splits1 = rf_dist_list.main(tree1_copy.copy(), tree2_copy.copy())
    for lists in splits1:
        # remove the ancestor of each side of the reported split
        node = tree1_copy.get_common_ancestor(lists[0])
        node.delete()
        node = tree1_copy.get_common_ancestor(lists[1])
        node.delete()

    # Both roots are annotated with the (same) list of all shared leaves.
    root1 = tree1.get_tree_root()
    root1.add_features(descen=overlap)
    root2 = tree2.get_tree_root()
    root2.add_features(descen=overlap)

    # Every other node on a path from a shared leaf to the root collects the
    # shared leaves found below it.
    for leaves in overlap:
        _mark_overlap_ancestors(tree1, leaves)
        _mark_overlap_ancestors(tree2, leaves)

    # Attach the unshared parts of both input trees to the backbone.
    tree1_copy = _graft_unshared_subtrees(tree1, tree1_copy)
    tree1_copy = _graft_unshared_subtrees(tree2, tree1_copy)

    return tree1_copy.write(format=9)
Example #31
0
# Default location of the INDELible executable; callers may override it through
# the ``indelible_path`` argument of simulate_sequence_data().
_DEFAULT_INDELIBLE_PATH = (
    "/groups/itay_mayrose/halabikeren/programs/indelible/INDELibleV1.03/src/indelible"
)

# Codons in the row-by-row order of the [statefreq] table of the control file
# template below (T, C, A, G at each of the three codon positions).
_STATEFREQ_CODONS = [
    first + second + third
    for first in "TCAG" for second in "TCAG" for third in "TCAG"
]

# INDELible control file; every <placeholder> is substituted before writing.
_INDELIBLE_CONTROL_TEMPLATE = '''[TYPE] CODON 1                    

[SETTINGS]
    [ancestralprint]    FALSE
	[fastaextension]    fas
	[output]          	FASTA
	[fileperrep]      	TRUE
    [printrates]        TRUE

  /* Notice that the number of classes and proportions do not change below. */

                   // Kap p0  p1  w0  w1  w2     (p2=1-p1-p0=0.5) 
[MODEL] BG [submodel] <kappa> <omega0_weight> <omega1_weight> <bg_omega0> <bg_omega1> <bg_omega2>
[MODEL] FG [submodel] <kappa> <omega0_weight> <omega1_weight> <fg_omega0> <fg_omega1> <fg_omega2>
    [statefreq]  
          <TTT_freq> <TTC_freq> <TTA_freq> <TTG_freq>    //  TTT  TTC  TTA  TTG
          <TCT_freq> <TCC_freq> <TCA_freq> <TCG_freq>    //  TCT  TCC  TCA  TCG 
          <TAT_freq> <TAC_freq> <TAA_freq> <TAG_freq>    //  TAT  TAC  TAA  TAG
          <TGT_freq> <TGC_freq> <TGA_freq> <TGG_freq>    //  TGT  TGC  TGA  TGG

          <CTT_freq> <CTC_freq> <CTA_freq> <CTG_freq>    //  CTT  CTC  CTA  CTG 
          <CCT_freq> <CCC_freq> <CCA_freq> <CCG_freq>    //  CCT  CCC  CCA  CCG 
          <CAT_freq> <CAC_freq> <CAA_freq> <CAG_freq>    //  CAT  CAC  CAA  CAG 
          <CGT_freq> <CGC_freq> <CGA_freq> <CGG_freq>    //  CGT  CGC  CGA  CGG 

          <ATT_freq> <ATC_freq> <ATA_freq> <ATG_freq>    //  ATT  ATC  ATA  ATG  
          <ACT_freq> <ACC_freq> <ACA_freq> <ACG_freq>    //  ACT  ACC  ACA  ACG  
          <AAT_freq> <AAC_freq> <AAA_freq> <AAG_freq>    //  AAT  AAC  AAA  AAG  
          <AGT_freq> <AGC_freq> <AGA_freq> <AGG_freq>    //  AGT  AGC  AGA  AGG 

          <GTT_freq> <GTC_freq> <GTA_freq> <GTG_freq>    //  GTT  GTC  GTA  GTG 
          <GCT_freq> <GCC_freq> <GCA_freq> <GCG_freq>    //  GCT  GCC  GCA  GCG  
          <GAT_freq> <GAC_freq> <GAA_freq> <GAG_freq>    //  GAT  GAC  GAA  GAG  
          <GGT_freq> <GGC_freq> <GGA_freq> <GGG_freq>    //  GGT  GGC  GGA  GGG
  /* 
     Like before, to get a correctly formatted [BRANCHES] block from a [TREE] block
     simply cut and paste the tree and change the branch lengths to model names.
     The stationary frequencies of the model at the root are used to generate the
     root sequence and this model defines the number of site categories used.
  */


[TREE]     t1  <tree_str>

[BRANCHES] b1  <tree_labels_str>


[PARTITIONS] Pname  [t1 b1 <aln_len>]    // tree t1, branchclass b1, root length 1000

[EVOLVE]     Pname  <num_of_replicates>  sequence_data  // 10 replicates generated from partition Pname'''


def _render_indelible_control_file(substitutions, codon_to_frequency):
    """Return the control file text with all template placeholders filled in."""
    content = _INDELIBLE_CONTROL_TEMPLATE
    for placeholder, value in substitutions:
        content = content.replace(placeholder, value)
    for codon in _STATEFREQ_CODONS:
        content = content.replace("<" + codon + "_freq>",
                                  str(codon_to_frequency[codon]))
    return content


def _remove_file_if_present(path):
    """Delete the file at ``path``; a missing file is silently ignored."""
    if os.path.isfile(path):
        os.remove(path)


def simulate_sequence_data(kappa, omega0, omega1, omega2, omega0_weight,
                           omega1_weight, selection_intensity_parameter,
                           true_history_path, output_dir, num_of_replicates,
                           aln_len, nuc1_theta, nuc1_theta1, nuc1_theta2,
                           nuc2_theta, nuc2_theta1, nuc2_theta2, nuc3_theta,
                           nuc3_theta1, nuc3_theta2,
                           indelible_path=_DEFAULT_INDELIBLE_PATH):
    """
    Simulate codon sequence data with INDELible along a labeled character history.

    The tree in ``true_history_path`` is read with ete3 (newick format 1). Every
    non-root node name must contain a ``{label}`` suffix: label ``"0"`` marks a
    background (BG) branch, any other label a foreground (FG) branch. An INDELible
    control file is written to ``<output_dir>sequence_data/control.txt``, INDELible
    is executed in that directory, and its log and "TRUE" alignment are removed.

    :param kappa, omega0, omega1, omega2, omega0_weight, omega1_weight: numeric
        values written (15 decimals) into the BG [MODEL] line of the control file
    :param selection_intensity_parameter: exponent applied to omega0/1/2 to obtain
        the omegas of the FG model
    :param true_history_path: path of the newick file holding the labeled history
    :param output_dir: directory prefix; "sequence_data/" is appended verbatim, so
        it is expected to end with a path separator
    :param num_of_replicates: value of the [EVOLVE] replicate count
    :param aln_len: value of the [PARTITIONS] root length
    :param nuc1_theta ... nuc3_theta2: forwarded to compute_codon_frequencies()
    :param indelible_path: command used to launch INDELible
    :return: tuple (path of "sequence_data_1.fas", labels string); the labels
        string has two lines, "model1.nodes_id = ..." listing the BG branch
        indices and "model2.nodes_id = ..." listing the FG branch indices

    NOTE: as a side effect the process working directory is changed to the
    sequence output directory (INDELible reads "control.txt" from the cwd).
    """
    # prepare directory for simulation output
    sequence_output_dir = output_dir + "sequence_data/"
    # Resolve the directory once, before changing the working directory, so that
    # a relative output_dir keeps referring to the same location afterwards.
    work_dir = os.path.abspath(sequence_output_dir)
    if not os.path.isdir(work_dir):
        os.makedirs(work_dir)
    control_file_path = os.path.join(work_dir, "control.txt")

    # read the character history and derive from it a tree and a labeling in
    # INDELible control file compatible format
    true_history = Tree(true_history_path, format=1)
    label_regex = re.compile(r"\{(.*?)\}", re.DOTALL)
    # maps node name -> "<name>:<branch length>" exactly as it is expected to
    # appear in the newick string that includes internal node names
    node_to_branch_length_expression = dict()
    node_to_label = dict()
    # postorder index of each branch, used for the bpp "nodes_id" lines
    node_to_branch_index = dict()
    branch_index = 0
    for node in true_history.traverse("postorder"):
        if node.is_root():
            continue
        node_label = label_regex.search(node.name).group(1)
        node_name = node.name.replace("{" + node_label + "}", "")
        node.name = node_name  # update the name of the node to exclude the label
        node_to_label[node_name] = "BG" if node_label == "0" else "FG"
        node_to_branch_index[node_name] = branch_index
        branch_index += 1
        # integral branch lengths are rendered without a decimal part
        node_dist = node.dist
        if int(node.dist) == node.dist:
            node_dist = int(node.dist)
        node_to_branch_length_expression[
            node_name] = node_name + ":" + str(node_dist)

    # newick representation of the tree without internal node names
    tree_str = true_history.write(outfile=None, format=5)
    # fix tree str number formatting
    tree_str = fix_tree_str_format(tree_str)
    # newick representation of the tree with internal node names; branch length
    # expressions are replaced below by the "#BG" / "#FG" model labels
    tree_labels_str = true_history.write(outfile=None, format=1)
    bpp_bg_labels_str = "model1.nodes_id = "
    bpp_fg_labels_str = "model2.nodes_id = "
    for node_name in node_to_label:
        label = node_to_label[node_name]
        expression = node_to_branch_length_expression[node_name]
        node = true_history.search_nodes(name=node_name)[0]
        # leaves keep their name in the [BRANCHES] tree, internal nodes do not
        replacement = " #" + label
        if node.is_leaf():
            replacement = node_name + replacement
        before = tree_labels_str
        tree_labels_str = tree_labels_str.replace(expression, replacement)
        if before == tree_labels_str:
            print("\nfailed to replace expression")
            print("node name: ", node_name)
            print("branch length expression: ", expression)
            print("node label expression: ", " #" + label)
        if label == "BG":
            bpp_bg_labels_str = bpp_bg_labels_str + str(
                node_to_branch_index[node_name]) + ","
        else:
            bpp_fg_labels_str = bpp_fg_labels_str + str(
                node_to_branch_index[node_name]) + ","
    # set the label of the root as the label of one of its immediate sons
    root = true_history.get_tree_root()
    root_son = root.get_children()[0]
    son_label = node_to_label[root_son.name]
    tree_labels_str = tree_labels_str.replace(";", "#" + son_label + ";")
    # [:-1] drops the trailing comma of each list
    labels_str = bpp_bg_labels_str[:-1] + "\n" + bpp_fg_labels_str[:-1]

    # compute codon frequencies
    codon_to_frequency = compute_codon_frequencies(nuc1_theta, nuc1_theta1,
                                                   nuc1_theta2, nuc2_theta,
                                                   nuc2_theta1, nuc2_theta2,
                                                   nuc3_theta, nuc3_theta1,
                                                   nuc3_theta2)

    # create control file
    substitutions = [
        ("<kappa>", "%.15f" % kappa),
        ("<bg_omega0>", "%.15f" % omega0),
        ("<bg_omega1>", "%.15f" % omega1),
        ("<bg_omega2>", "%.15f" % omega2),
        ("<omega0_weight>", "%.15f" % omega0_weight),
        ("<omega1_weight>", "%.15f" % omega1_weight),
        ("<fg_omega0>", "%.15f" % (omega0**selection_intensity_parameter)),
        ("<fg_omega1>", "%.15f" % (omega1**selection_intensity_parameter)),
        ("<fg_omega2>", "%.15f" % (omega2**selection_intensity_parameter)),
        ("<tree_str>", tree_str),
        ("<tree_labels_str>", tree_labels_str),
        ("<num_of_replicates>", str(num_of_replicates)),
        ("<aln_len>", str(aln_len)),
    ]
    control_file_content = _render_indelible_control_file(
        substitutions, codon_to_frequency)
    with open(control_file_path, "w") as control_file:
        control_file.write(control_file_content)

    # execute INDELible from inside the output directory (a single chdir to the
    # absolute path; repeating a relative chdir would descend a level each time)
    os.chdir(work_dir)
    os.system(indelible_path)

    # check if the simulation is done, and sleep until done
    # NOTE(review): this waits indefinitely if INDELible produced no output
    simulated_alignment_path = os.path.join(work_dir, "sequence_data_1.fas")
    while not os.path.exists(simulated_alignment_path):
        sleep(3)

    # drop INDELible's by-products and post-process the alignment file
    _remove_file_if_present(os.path.join(work_dir, "LOG.txt"))
    _remove_file_if_present(os.path.join(work_dir, "sequence_data_TRUE_1.fas"))
    remove_spaces(simulated_alignment_path)

    return sequence_output_dir + "sequence_data_1.fas", labels_str
Example #32
0
    def tmrca_graph(self,
                    sites_to_newick_mappings,
                    labels,
                    topology_only=False,
                    subplotPosition=111):
        """
            Plots a line graph comparing tree heights from different MS files.

            Inputs:
                i. sites_to_newick_mappings -- a list of the mappings outputted by sites_to_newick_ms()
                   (each maps a site to a newick tree)
                ii. labels -- legend labels, one per mapping (same order)
                iii. topology_only: If set to True, distance between nodes will be referred to the number of nodes between them.
                    In other words, topological distance will be used instead of branch length distances.
                iv. subplotPosition -- subplot specifier passed to plt.subplot()

            Returns:
                i. The matplotlib axes holding a line graph with the tree height as the y-axis
                   and the site number as the x-axis.
        """

        print(labels)

        ax = plt.subplot(subplotPosition)

        ax.set_title('TMRCA Line Graph')
        ax.set_xlabel('SNP Site Number')
        ax.set_ylabel('TMRCA')

        # one line per mapping
        for i, mapping in enumerate(sites_to_newick_mappings):
            # x values come from this mapping's own sites so that they always
            # pair up with the heights computed from the same mapping
            sites = list(mapping.keys())

            heights = []
            for site in sites:
                root = Tree(mapping[site]).get_tree_root()
                # get_farthest_leaf returns (leaf, distance); the distance from
                # the root to its farthest leaf is the tree height. Use the
                # number directly instead of parsing it out of a string.
                _, height = root.get_farthest_leaf(topology_only)
                heights.append(height)

            # colors and patterns are cycled so that any number of mappings
            # can be plotted
            ax.plot(sites,
                    heights,
                    c=self.COLORS[i % len(self.COLORS)],
                    linestyle=self.PATTERNS[i % len(self.PATTERNS)],
                    label=labels[i])

        leg = ax.legend()
        if leg:
            # Legend.draggable() was replaced by set_draggable() in newer
            # matplotlib releases; support both.
            if hasattr(leg, "set_draggable"):
                leg.set_draggable(True)
            else:
                leg.draggable()

        return ax
Example #33
0
class UniFrac(object):
    """
    UniFrac distances between groups of samples on a hierarchical clustering tree.

    note that the whole association between metadata and leaf nodes works by .loc:
    the nodes are named according to the dataframe index and we look up a node's
    metadata with df_metadata.loc[node.name]

    Typical call order: build_tree() -> cluster() -> unifrac_distance() / visualize()
    """
    def __init__(self, datamatrix, df_metadata):
        """
        datamatrix: observations passed to `linkage` in build_tree()
        df_metadata: dataframe whose (unique) index names the tree leaves
        """
        super(UniFrac, self).__init__()

        # make sure that the dataframe index is unique
        assert len(df_metadata) == len(
            set(df_metadata.index)
        ), 'row-index is not unique, but we need uniqueness to associate the metadata with the leaves in the tree'

        self.datamatrix = datamatrix
        self.df_metadata = df_metadata
        self.tree = None
        self._linkage = None  # just kept to do the cut_trees call
        self.cluster_roots = None  # just kept to do the cut_trees call
        self.nodes2leaves = None  # for caching leaf lookups, however this returns a set!!

    def _update_leave_metadata(self):
        "puts the metadata in self.df_metadata as features into the tree's leaves"
        assert self.tree

        # to speed things up, query the dataframe only once
        leaves = self.tree.get_leaves()
        leavenames = [leave.name for leave in leaves]
        # .loc with a list sorts the metadata in the same order as leavenames
        meta = self.df_metadata.loc[leavenames].values
        featurenames = self.df_metadata.columns.values
        for i, leaf in enumerate(leaves):
            leaf.add_features(**dict(zip(featurenames, meta[i, :])))
            #TODO not sure if this overwrites previous features (that's what i want) or just adds additional features!

    def build_tree(self, method, metric):
        """
        constructs the hierarchical clustering tree,
        but no clustering (corresponding to some tree pruning) in here

        method, metric: forwarded to `linkage`
        """
        self._linkage = linkage(self.datamatrix, method=method, metric=metric)
        # turn it into an ete tree
        leave_labels = self.df_metadata.index.values
        newick_tree = linkage_to_newick(self._linkage, labels=leave_labels)
        self.tree = Tree(newick_tree)
        # makes it easy to lookup leaves of a node
        self.nodes2leaves = self.tree.get_cached_content()
        # populate the leaves with metadata
        self._update_leave_metadata()

    def cluster(self, n_clusters):
        """
        prunes the hierarchical clustering tree to get clusters of data
        this clustering is also added to the metadata (column 'clustering')
        also adds the self.cluster_roots (caching it, we need it in unifrac calls)
        """
        assert self.tree
        clustering_prune = cut_tree(self._linkage, n_clusters)
        self.df_metadata['clustering'] = clustering_prune
        self._update_leave_metadata()
        self.cluster_roots = find_cluster_roots(self.tree)

        for i, cluster_root in enumerate(self.cluster_roots):
            cluster_root.add_features(
                **{
                    'is_cluster_root': i,
                    'n_datapoints': len(self.nodes2leaves[cluster_root])
                })

    def unifrac_distance(self, group1, group2, randomization=None):
        """
        calculates the uniFrac distance of the two sample-groups
        group1: list of nodenames (i.e. indices of the metadata)
        group2: ---"---
        randomization: (int) how many times to compute the 'randomized' uniFrac distance to get a pvalue

        returns the distance alone when no randomization is requested, otherwise
        the tuple (distance, array of randomized distances, empirical pvalue),
        the pvalue being the fraction of randomized distances above the distance
        """
        assert 'clustering' in self.df_metadata.columns and self.cluster_roots, "run cluster() first"

        # the cached leaf set of the root is much cheaper than tree.get_leaves()
        the_root = self.tree.get_tree_root()
        all_leaves = self.nodes2leaves[the_root]

        # sets give constant-time membership tests for the checks below
        leaf_names = {leaf.name for leaf in all_leaves}
        group1_names = set(group1)
        group2_names = set(group2)

        # make sure all group elements are in the tree
        assert group1_names <= leaf_names, "group1 contains names that are not leaves of the tree"
        assert group2_names <= leaf_names, "group2 contains names that are not leaves of the tree"

        group1_nodes = {leaf for leaf in all_leaves if leaf.name in group1_names}
        group2_nodes = {leaf for leaf in all_leaves if leaf.name in group2_names}

        the_distance = self._unifrac_dist(group1_nodes, group2_nodes)

        if not (randomization and randomization > 0):
            return the_distance

        G1 = len(group1_nodes)
        # union, but turn into list for partitioning later
        all_nodes = list(group1_nodes | group2_nodes)
        randomized_distances = []
        for _ in range(randomization):
            shuffle(all_nodes)  # inplace shuffle
            # random split of the pooled nodes, first part sized like group1
            randomized_distances.append(
                self._unifrac_dist(set(all_nodes[:G1]), set(all_nodes[G1:])))

        randomized_distances = np.array(randomized_distances)

        # empirical pvalue
        p2 = np.sum(randomized_distances > the_distance) / len(
            randomized_distances)
        return the_distance, randomized_distances, p2

    def _unifrac_dist(self, group1_nodes, group2_nodes):
        "given two collections of leaf nodes, calculate the unifrac distance"
        At, Bt = len(group1_nodes), len(group2_nodes)
        nominators = []
        denominators = []
        for current_cluster_root in self.cluster_roots:
            # all the datapoints in the cluster
            leafs = self.nodes2leaves[current_cluster_root]
            Ai = sum(1 for leaf in leafs if leaf in group1_nodes)
            Bi = sum(1 for leaf in leafs if leaf in group2_nodes)
            distance2root = current_cluster_root.distance2root  # cached already
            nominators.append(distance2root * np.abs(Ai / At - Bi / Bt))
            denominators.append(distance2root * np.abs(Ai / At + Bi / Bt))

        return sum(nominators) / sum(denominators)

    def visualize(self, group1=None, group2=None):
        """
        shows the tree, collapsed at the cluster roots and colored by cluster

        group1, group2: optional lists of nodenames (as in unifrac_distance);
        when at least one is given, every cluster root is annotated with the
        number of its datapoints that belong to each group
        """
        import matplotlib
        import matplotlib.pyplot as plt

        assert self.cluster_roots is not None, "run cluster() first"

        # annotate the cluster roots with their group counts
        if group1 or group2:
            # a missing group simply counts as empty
            group1_names = set(group1 or ())
            group2_names = set(group2 or ())
            for cluster_root in self.cluster_roots:
                # count downstream conditions in the leafs
                datapoints_in_cluster = self.nodes2leaves[cluster_root]
                n_group1 = sum(1 for leaf in datapoints_in_cluster
                               if leaf.name in group1_names)
                n_group2 = sum(1 for leaf in datapoints_in_cluster
                               if leaf.name in group2_names)
                cluster_root.add_face(
                    TextFace(f"Group1: {n_group1}// Group2:{n_group2}"),
                    column=0,
                    position="branch-right")

        # color tables are built once instead of on every layout call
        cmap_cluster = plt.cm.tab10(np.linspace(0, 1, len(self.cluster_roots)))
        cmap_treated = plt.cm.viridis(np.linspace(0, 1, 2))

        def _custom_layout(node):
            if node.is_leaf():
                c_cluster = matplotlib.colors.rgb2hex(
                    cmap_cluster[node.clustering, :])
                c_treat = matplotlib.colors.rgb2hex(
                    cmap_treated[node.treated, :])
                node.img_style["fgcolor"] = c_treat
                node.img_style["bgcolor"] = c_cluster

            if 'is_cluster_root' in node.features:
                c_cluster = matplotlib.colors.rgb2hex(
                    cmap_cluster[node.is_cluster_root, :])
                node.img_style["bgcolor"] = c_cluster
                node.img_style["draw_descendants"] = False
                node.add_face(TextFace(f"#data:{node.n_datapoints}"),
                              column=0,
                              position="branch-right")

        ts = TreeStyle()
        ts.mode = "r"
        ts.show_leaf_name = False
        ts.arc_start = -180  # 0 degrees = 3 o'clock
        ts.arc_span = 270
        ts.layout_fn = _custom_layout
        self.tree.show(tree_style=ts)
from ete3 import Tree
taxonomy = Tree("skills_taxonomy_tree_1.nw")

# find the depth of the tree: distance from the root to its farthest leaf
node, depth = taxonomy.get_farthest_leaf()
print(depth)

root = taxonomy.get_tree_root()
depth = int(depth)

# Give every node that sits at a whole-number distance from the root (level
# 0 .. depth) a 'level_score' in [0, 1]: its level divided by the tree depth.
# A single traversal is enough; nodes at non-integral distances get no score,
# and nothing is scored when the depth is 0 (the ratio would be undefined).
for node in taxonomy.traverse("postorder"):
    distance = node.get_distance(node, root)
    if depth > 0 and 0 <= distance <= depth and float(distance).is_integer():
        node.add_feature('level_score', distance / depth)
Example #35
0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys, os, argparse
from ete3 import Tree

# Re-root a newick tree on the given outgroup and print the result.
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tree')
parser.add_argument('-o', '--outgroup')
opts = parser.parse_args(sys.argv[1:])

# check arguments
# argparse leaves an omitted option as None, so test for that explicitly
if opts.tree is None or not os.path.isfile(opts.tree):
    sys.stderr.write("File %s not found\n" % opts.tree)
    sys.exit(1)

if not opts.outgroup:
    sys.stderr.write("Outgroup must be defined (--outgroup)\n")
    sys.exit(1)


t = Tree(opts.tree)
t.set_outgroup(opts.outgroup)
# format=5: newick with branch lengths and leaf names only
print(t.get_tree_root().write(format=5))