Example #1
0
def run_rescale(prefix, tree, data, n_proc=5):
    """Re-estimate branch lengths of a fixed topology with RAxML and average them.

    For every ``(phy, weights, asc, invariants)`` entry in ``data`` RAxML is
    run in ``-f e`` mode on the topology in ``tree``.  The branch lengths of
    each result are collected per bipartition, keyed by the sorted tuple of
    leaf names below the branch.  Each branch of ``tree`` then receives the
    mean of the lengths collected for its bipartition.  The bipartition is
    looked up from either side, or the branch gets 0 when it was never seen.

    :param prefix: RAxML run name (``-n``) and prefix of the output file.
    :param tree: newick topology file passed to RAxML with ``-t``.
    :param data: iterable of ``(phy, weights, asc, invariants)`` tuples.
        ``asc`` is None for the plain GTRGAMMA model, otherwise the file
        passed to ``-q`` for the ascertainment-corrected model.
        ``invariants`` is a mapping whose values are summed.
    :param n_proc: number of RAxML threads (``-T``).
    :return: path of the written ``<prefix>.unrooted.nwk`` file.

    Side effect: all ``RAxML_*.<prefix>`` files are deleted, as are the
    per-entry inputs ``phy``, ``phy + '.reduced'``, ``weights`` and ``asc``.
    Relies on module-level ``raxml`` (executable) and ``rint`` (``-p`` seed).
    """
    branches = {}
    cnt = 0
    for phy, weights, asc, invariants in data:
        cnt += sum(invariants.values())
        # Clear output left over from an earlier run with the same prefix.
        for fname in glob.glob('RAxML_*.{0}'.format(prefix)):
            os.unlink(fname)
        # Build argv directly instead of str.split() on a formatted command
        # line, so paths containing spaces are passed through intact.
        if asc is None:
            args = [raxml, '-m', 'GTRGAMMA', '-n', prefix, '-t', tree,
                    '-f', 'e', '-D', '-s', phy, '-a', weights,
                    '-T', n_proc, '-p', rint, '--no-bfgs']
        else:
            args = [raxml, '-m', 'ASC_GTRGAMMA', '-n', prefix, '-t', tree,
                    '-f', 'e', '-D', '-s', phy, '-a', weights,
                    '-T', n_proc, '-p', rint,
                    '--asc-corr', 'stamatakis', '--no-bfgs', '-q', asc]
        run = Popen([str(arg) for arg in args])
        run.communicate()

        tre = Tree('RAxML_result.{0}'.format(prefix), format=0)
        with open(phy + '.subtree', 'w') as fout:
            fout.write(tre.write(format=0) + '\n')
        # Post-order guarantees children are labelled before their parent.
        for node in tre.get_descendants('postorder'):
            if node.is_leaf():
                node.d = [node.name]
            else:
                node.d = [n for c in node.children for n in c.d]
            branches.setdefault(tuple(sorted(node.d)), []).append(node.dist)

        # Best-effort cleanup. asc is None for the plain model and must be
        # skipped (os.unlink(None) raises TypeError). Files that are already
        # gone are ignored.
        leftovers = glob.glob('RAxML_*.{0}'.format(prefix)) + [
            phy, phy + '.reduced', weights, asc]
        for fn in leftovers:
            if fn is None:
                continue
            try:
                os.unlink(fn)
            except OSError:
                pass

    tre = Tree(tree, format=1)
    leaves = set(tre.get_leaf_names())
    for node in tre.get_descendants('postorder'):
        if node.is_leaf():
            node.d = [node.name]
        else:
            node.d = [n for c in node.children for n in c.d]
        # A bipartition may have been recorded from either side.
        key1 = tuple(sorted(node.d))
        key2 = tuple(sorted(leaves - set(node.d)))
        if key1 in branches:
            node.dist = np.mean(branches[key1])
        elif key2 in branches:
            node.dist = np.mean(branches[key2])
        else:
            node.dist = 0.
        # Zero branches whose length scaled by the invariant-site count rounds to 0.
        if -0.5 < node.dist * cnt < 0.5:
            node.dist = 0.0

    fname = '{0}.unrooted.nwk'.format(prefix)
    tre.write(outfile=fname, format=0)
    return fname
Example #2
0
class NTBDB(object):
    """Image metadata indexed by tag, plus a newick hierarchy of tags.

    ``metadata.pickle`` in ``metadata_dir`` maps image keys to dicts that have
    at least ``'tags'``, ``'folder'`` and ``'filename'`` entries.
    ``tags.nw`` holds the tag tree, which is read with ete3 format 8.
    """

    def __init__(self, imgs_dir='/storage/imgs/low-res', metadata_dir='/storage/metadata'):
        self.metadata_dir = metadata_dir
        self.imgs_dir = imgs_dir
        # Pickles are binary. Text mode breaks pickle.load on Python 3 (and on
        # Windows under Python 2).
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # metadata_dir must point at trusted data only.
        with open(os.path.join(self.metadata_dir, 'metadata.pickle'), 'rb') as md:
            self.metadata = pickle.load(md)
        for img in EXCLUDE_PICS:
            del self.metadata[img]
        # tag name -> list of metadata dicts carrying that tag.
        self.by_tag = dict()
        # .values() works on Python 2 and 3; .itervalues() is Python 2 only.
        for p in self.metadata.values():
            for tag in p['tags']:
                self.by_tag.setdefault(tag, []).append(p)

        self.tags = Tree(os.path.join(self.metadata_dir, 'tags.nw'), format=8)
        self.tag_by_name = {tag.name: tag for tag in self.tags.get_descendants()}

    def by_tag_with_children(self, tag_name):
        """Return the images tagged ``tag_name`` or any descendant tag.

        Uses the first tag node named ``tag_name`` (IndexError if none).
        An image with several matching tags appears once per matching tag.
        """
        tag_node = self.tags.search_nodes(name=tag_name)[0]
        all_tags = [tag_node]
        all_tags.extend(tag_node.get_descendants())
        return list(itertools.chain.from_iterable(self.by_tag.get(tag.name, []) for tag in all_tags))

    def tag_score(self, tag):
        """Return the number of image-tag pairs in the subtree rooted at ``tag``."""
        return len(self.by_tag.get(tag.name, [])) + sum(map(self.tag_score, tag.children))

    def top_tags(self, max_children=5, max_depth=2):
        """Return a pruned copy of the tag tree, keeping the most-used tags.

        Each node keeps only its ``max_children`` highest-scoring children
        (scored as in :meth:`tag_score`).  Nodes farther than
        ``max_depth - 1`` from the root lose all children.  Depth is measured
        with ``get_distance``, i.e. by branch length.
        NOTE(review): this assumes unit branch lengths in tags.nw.
        """
        top_tags = self.tags.copy()
        # Score every node once, bottom-up. Using tag_score as the sort key
        # re-walked each subtree once for every ancestor.
        scores = {}
        for n in top_tags.traverse('postorder'):
            scores[n] = len(self.by_tag.get(n.name, [])) + sum(scores[c] for c in n.children)
        for n in top_tags.traverse():
            n.children = sorted(n.children, key=scores.__getitem__, reverse=True)[:max_children]
            if n.get_distance(n.get_tree_root()) > max_depth - 1:
                n.children = []
        return top_tags

    def image_path(self, image_index):
        """Return ``<imgs_dir>/<folder>/<filename>.jpg`` for the given image key."""
        return os.path.join(self.imgs_dir, self.metadata[image_index]['folder'], self.metadata[image_index]['filename'] + '.jpg')
Example #3
0
def read_prunned_data():
    """Load the input tree and prune it to a random sample of named nodes.

    If ``INPUT_DATA_FILE`` does not exist yet it is first produced by
    ``save_data_to_file``.  ``NUMBER`` descendants are drawn with replacement
    (``random.choices``).  Unnamed nodes and nodes whose name is purely
    numeric are discarded, and the tree is pruned to the remaining names.

    :return: the pruned ete3 tree.
    """
    if not os.path.exists(INPUT_DATA_FILE):
        save_data_to_file(INPUT_DATA_FILE)
    tree = Tree(INPUT_DATA_FILE, format=1)
    sampled = random.choices(tree.get_descendants(), k=NUMBER)
    keep = []
    for node in sampled:
        label = node.name
        # Skip empty labels and numeric ones (presumably internal-node values).
        if label and not label.isdecimal():
            keep.append(label)
    tree.prune(keep)
    return tree
def collapse_nodes(inputfile, threshold, output_name):
    """
    Collapse internal nodes whose support is below a threshold.

    Every non-root internal node with ``support < threshold`` is deleted.
    Its children are reattached to its parent, and branch lengths are kept
    (``preserve_branch_length=True``).  Leaves are never removed.

    :param inputfile: The tree file
    :param threshold: Support value below which an internal node is collapsed
    :param output_name: The output file name
    :return: None
    """
    tree = Tree(inputfile)

    for candidate in tree.get_descendants():
        weakly_supported = candidate.support < threshold
        if weakly_supported and not candidate.is_leaf():
            candidate.delete(preserve_branch_length=True)

    tree.write(outfile=output_name)
Example #5
0
from ete3 import Tree
tree = Tree('(A:1,(B:1,(C:1,D:1):0.5):0.5);')
# Print the name of every leaf under the tree root.
# print() calls: the Python 2 print statement is a SyntaxError on Python 3.
print("Leaf names:")
for leaf in tree.get_leaves():
    print(leaf.name)
# Label nodes as terminal or internal. Internal nodes also record the number
# of leaves they contain.
print("Labeled tree:")
for node in tree.get_descendants():
    if node.is_leaf():
        node.add_features(ntype="terminal")
    else:
        node.add_features(ntype="internal", size=len(node))
# Print the extended newick of the tree, including the new node features.
print(tree.write(features=[]))
Example #6
0
class um_tree:
    """Tree wrapper used for a single-threshold GMYC-style delimitation.

    The tree is read with ete3 (format 1), and polytomies are resolved with
    1e-6 branches.  Every node gets an ``age`` feature equal to its distance
    from the root.  ``self.nodes`` holds the internal nodes, each with an
    integer ``id`` feature.  It also holds the node farthest from the root
    when that node is a leaf.  The list is sorted by age.
    """

    def __init__(self, tree):
        self.tree = Tree(tree, format=1)
        self.tree.resolve_polytomy(default_dist=0.000001, recursive=True)
        self.tree.dist = 0
        self.tree.add_feature("age", 0)
        self.nodes = self.tree.get_descendants()
        internal_node = []
        cnt = 0
        for n in self.nodes:
            node_age = n.get_distance(self.tree)
            n.add_feature("age", node_age)
            if not n.is_leaf():
                n.add_feature("id", cnt)
                cnt = cnt + 1
                internal_node.append(n)
        self.nodes = internal_node
        one_leaf = self.tree.get_farthest_node()[0]
        one_leaf.add_feature("id", cnt + 1)
        if one_leaf.is_leaf():
            self.nodes.append(one_leaf)
        self.nodes.sort(key=self.__compare_node)
        self.species_list = []
        self.coa_roots = None

    def __compare_node(self, node):
        return node.age

    def get_waiting_times(self, threshold_node=None, threshold_node_idx=0):
        """Split the sorted nodes into waiting intervals around a threshold node.

        Nodes before the threshold each add a lineage (``curr_spe``).  From
        the threshold node on, each node either starts a new coalescent group
        or joins the group rooted at its closest non-root ancestor that
        already roots one.  ``self.coa_roots`` and ``self.species_list`` are
        rebuilt as a side effect.

        :param threshold_node: node at which coalescence starts. Defaults to
            ``self.nodes[threshold_node_idx]``.
        :param threshold_node_idx: index used when ``threshold_node`` is None.
        :return: ``(wt_list, num_spe)``, i.e. the waiting_time objects for
            intervals longer than 1e-8 and the lineage count at the threshold
            (-1 if it was never reached).
        """
        wt_list = []
        reach_t = False
        curr_age = 0.0
        curr_spe = 2
        curr_num_coa = 0
        coa_roots = []
        num_spe = -1

        if threshold_node is None:
            threshold_node = self.nodes[threshold_node_idx]

        last_coa_num = 0
        tcnt = 0
        for node in self.nodes:
            wt = None
            times = node.age - curr_age
            if times >= 0:
                curr_age = node.age
                assert curr_spe >= 0

                if reach_t:
                    if tcnt == 0:
                        last_coa_num = 2
                    # Closest ancestor (below the root) that already roots a
                    # coalescent group, if any.
                    coa_root = None
                    fnode = node.up
                    while not fnode.is_root():
                        for coa_r in coa_roots:
                            if coa_r.id == fnode.id:
                                coa_root = coa_r
                                break
                        if coa_root is not None:
                            break
                        fnode = fnode.up

                    wt = waiting_time(length=times,
                                      num_coas=curr_num_coa,
                                      num_lines=curr_spe)

                    for coa_r in coa_roots:
                        coa = coalescent(num_individual=coa_r.curr_n)
                        wt.coas.add_coalescent(coa)

                    wt.coas.coas_idx = last_coa_num
                    wt.num_curr_coa = last_coa_num
                    if coa_root is None:  # here can be modified to use multiple T
                        # This node starts a new coalescent group.
                        curr_spe = curr_spe - 1
                        curr_num_coa = curr_num_coa + 1
                        node.add_feature("curr_n", 2)
                        coa_roots.append(node)
                        last_coa_num = 2
                    else:
                        # This node adds one lineage to an existing group.
                        curr_n = coa_root.curr_n
                        coa_root.add_feature("curr_n", curr_n + 1)
                        last_coa_num = curr_n + 1
                    tcnt = tcnt + 1

                elif node.id == threshold_node.id:
                    reach_t = True
                    tcnt = 0
                    wt = waiting_time(length=times,
                                      num_coas=0,
                                      num_lines=curr_spe)
                    num_spe = curr_spe
                    curr_spe = curr_spe - 1
                    curr_num_coa = 2
                    node.add_feature("curr_n", 2)
                    coa_roots.append(node)
                else:
                    wt = waiting_time(length=times,
                                      num_coas=0,
                                      num_lines=curr_spe)
                    curr_spe = curr_spe + 1
            # Intervals of length <= 1e-8 are not recorded.
            if times > 0.00000001:
                wt_list.append(wt)

        for wt in wt_list:
            wt.count_num_lines()

        self.species_list = []
        all_coa_leaves = []
        self.coa_roots = coa_roots
        for coa_r in coa_roots:
            leaves = coa_r.get_leaves()
            all_coa_leaves.extend(leaves)
            self.species_list.append(leaves)

        # Every leaf outside a coalescent group is its own species. A set
        # avoids the O(n^2) list membership test.
        coa_leaf_set = set(all_coa_leaves)
        for leaf in self.tree.get_leaves():
            if leaf not in coa_leaf_set:
                self.species_list.append([leaf])

        return wt_list, num_spe

    def show(self, wt_list):
        cnt = 1
        for wt in wt_list:
            print("Waitting interval " + repr(cnt))
            print(wt)
            cnt = cnt + 1

    @staticmethod
    def _line_style(fgcolor, vt_line_color, hz_line_color):
        """Return a NodeStyle with solid 2px lines, the given colours and no node marker."""
        style = NodeStyle()
        style["fgcolor"] = fgcolor
        style["vt_line_color"] = vt_line_color
        style["hz_line_color"] = hz_line_color
        style["vt_line_width"] = 2
        style["hz_line_width"] = 2
        style["vt_line_type"] = 0  # 0 solid, 1 dashed, 2 dotted
        style["hz_line_type"] = 0
        style["size"] = 0
        return style

    def get_species(self):
        """Return ``([all leaf names], [[names of species k] ...])`` and colour the tree.

        Every node is drawn blue. Coalescent-group roots get a red vertical
        line, and their descendants are drawn fully red.
        """
        sp_list = [[taxa.name for taxa in sp] for sp in self.species_list]
        all_taxa_name = [leaf.name for leaf in self.tree.get_leaves()]

        style0 = self._line_style("#000000", "#0000aa", "#0000aa")
        for node in self.tree.get_descendants():
            node.set_style(style0)
            node.img_style["size"] = 0
        self.tree.set_style(style0)
        self.tree.img_style["size"] = 0

        style1 = self._line_style("#000000", "#ff0000", "#0000aa")
        style2 = self._line_style("#0f0f0f", "#ff0000", "#ff0000")
        for node in self.coa_roots:
            node.set_style(style1)
            node.img_style["size"] = 0
            for des in node.get_descendants():
                des.set_style(style2)
                des.img_style["size"] = 0
        return [all_taxa_name], sp_list

    def print_species(self, save_file):
        """Write the species partition to ``<save_file>/partition.txt``.

        Each species is written as a ``Species <k>:`` line followed by its
        taxa names separated by ", ".
        """
        with open(os.path.join(save_file, "partition.txt"), "w+") as out:
            for cnt, sp in enumerate(self.species_list, start=1):
                print("Species " + repr(cnt) + ":", file=out)
                # join() rather than slicing one char off "a, b, ", which
                # left a trailing comma behind.
                print(", ".join(taxa.name for taxa in sp), file=out)

    def print_species_spart(self, save_file):
        """Write the species partition in SPART layout to ``<save_file>/partition.spart``."""
        n_subsets = len(self.species_list)
        with open(os.path.join(save_file, "partition.spart"), "w+") as file3:
            file3.write("Filename=GMYC delimitation\n")
            file3.write(f'{datetime.datetime.now().astimezone().isoformat()}\n\n')
            file3.write(f"Npartition={1};GMYC\n")
            file3.write(f'Nsamples={sum(len(sp) for sp in self.species_list)}\n')
            file3.write(f'Nsubsets={n_subsets};{",".join("?" * n_subsets)}\n\n')
            file3.write("#this is my first comment\n")
            file3.write("#this is my second comment\n\n")
            file3.write("Assignment\n")
            for cnt, sp in enumerate(self.species_list, start=1):
                for taxa in sp:
                    file3.write(taxa.name + "\t" + repr(cnt) + ";?\n")
            file3.write("\nPartition_score=\n")

    def output_species(self, taxa_order=None):
        """Return ``(taxa_order, partition)``, where ``partition[i]`` is the
        1-based species index of ``taxa_order[i]``.

        ``taxa_order`` defaults to the tree's leaf names.  Returns
        ``(None, None)`` (after printing an error) when its length does not
        match the number of taxa in the partition.
        """
        if taxa_order is None or len(taxa_order) == 0:
            taxa_order = self.tree.get_leaf_names()
        num_taxa = sum(len(sp) for sp in self.species_list)
        if not len(taxa_order) == num_taxa:
            print("error error, taxa_order != num_taxa!")
            return None, None
        partion = [-1] * num_taxa
        for cnt, sp in enumerate(self.species_list, start=1):
            for taxa in sp:
                partion[taxa_order.index(taxa.name)] = cnt
        return taxa_order, partion

    def num_lineages(self, wt_list, save_file):
        """Plot lineage count against time and save it as ``<save_file>/Time_Lines.png``."""
        nl_list = []
        times = []
        last_time = 0.0
        for wt in wt_list:
            nl_list.append(wt.get_num_branches())
            times.append(last_time)
            last_time = wt.length + last_time
        plt.plot(times, nl_list)
        plt.ylabel('Number of lineages')
        plt.xlabel('Time')
        plt.savefig(os.path.join(save_file, "Time_Lines.png"))
Example #7
0
import os, uuid
from ete3 import Tree

SRC_DIR = "/Users/David/Downloads/Chunks"
OUT_DIR = os.path.join(SRC_DIR, "Chunks_90")

# Collapse internal nodes with support <= 0.9 in every .tre file of SRC_DIR
# and write the result under OUT_DIR with the same file name.
for fname in os.listdir(SRC_DIR):
    if fname.endswith(".tre"):
        outname = os.path.join(OUT_DIR, fname)
        # os.listdir returns bare names. Join with the directory so the tree
        # is read from SRC_DIR rather than the current working directory.
        t = Tree(os.path.join(SRC_DIR, fname), format=0)

        print(t.get_ascii(attributes=['support', 'name']))

        for node in t.get_descendants():
            if not node.is_leaf() and node.support <= 0.9:
                node.delete()

        print(t.get_ascii(attributes=['support', 'name']))

        t.write(format=0, outfile=outname)
Example #8
0
class um_tree:
    """Tree wrapper used for a single-threshold GMYC-style delimitation.

    The tree is read with ete3 (format 1), and polytomies are resolved with
    1e-6 branches.  Every node gets an ``age`` feature equal to its distance
    from the root.  ``self.nodes`` holds the internal nodes, each with an
    integer ``id`` feature.  It also holds the node farthest from the root
    when that node is a leaf.  The list is sorted by age.  ``PATH`` is the
    output directory used by :meth:`print_species`.
    """

    def __init__(self, tree, PATH):
        self.tree = Tree(tree, format=1)
        # Only the handle's .name is referenced (in commented-out code), so
        # close it immediately instead of leaking the descriptor.
        # NOTE(review): confirm no caller reads from self.tree2.
        with open(tree) as tree_handle:
            self.tree2 = tree_handle
        self.tree.resolve_polytomy(default_dist=0.000001, recursive=True)
        self.tree.dist = 0
        self.tree.add_feature("age", 0)
        self.nodes = self.tree.get_descendants()
        self.PATH = PATH
        internal_node = []
        cnt = 0
        for n in self.nodes:
            node_age = n.get_distance(self.tree)
            n.add_feature("age", node_age)
            if not n.is_leaf():
                n.add_feature("id", cnt)
                cnt = cnt + 1
                internal_node.append(n)
        self.nodes = internal_node
        one_leaf = self.tree.get_farthest_node()[0]
        one_leaf.add_feature("id", cnt + 1)
        if one_leaf.is_leaf():
            self.nodes.append(one_leaf)
        self.nodes.sort(key=self.__compare_node)
        self.species_list = []
        self.coa_roots = None

    def __compare_node(self, node):
        return node.age

    def get_waiting_times(self, threshold_node=None, threshold_node_idx=0):
        """Split the sorted nodes into waiting intervals around a threshold node.

        Nodes before the threshold each add a lineage (``curr_spe``).  From
        the threshold node on, each node either starts a new coalescent group
        or joins the group rooted at its closest non-root ancestor that
        already roots one.  ``self.coa_roots`` and ``self.species_list`` are
        rebuilt as a side effect.

        :return: ``(wt_list, num_spe)``, i.e. the waiting_time objects for
            intervals longer than 1e-8 and the lineage count at the threshold
            (-1 if it was never reached).
        """
        wt_list = []
        reach_t = False
        curr_age = 0.0
        curr_spe = 2
        curr_num_coa = 0
        coa_roots = []
        num_spe = -1

        if threshold_node is None:
            threshold_node = self.nodes[threshold_node_idx]

        last_coa_num = 0
        tcnt = 0
        for node in self.nodes:
            wt = None
            times = node.age - curr_age
            if times >= 0:
                curr_age = node.age
                assert curr_spe >= 0

                if reach_t:
                    if tcnt == 0:
                        last_coa_num = 2
                    # Closest ancestor (below the root) that already roots a
                    # coalescent group, if any.
                    coa_root = None
                    fnode = node.up
                    while not fnode.is_root():
                        for coa_r in coa_roots:
                            if coa_r.id == fnode.id:
                                coa_root = coa_r
                                break
                        if coa_root is not None:
                            break
                        fnode = fnode.up

                    wt = waiting_time(length=times,
                                      num_coas=curr_num_coa,
                                      num_lines=curr_spe)

                    for coa_r in coa_roots:
                        coa = coalescent(num_individual=coa_r.curr_n)
                        wt.coas.add_coalescent(coa)

                    wt.coas.coas_idx = last_coa_num
                    wt.num_curr_coa = last_coa_num
                    if coa_root is None:  # here can be modified to use multiple T
                        # This node starts a new coalescent group.
                        curr_spe = curr_spe - 1
                        curr_num_coa = curr_num_coa + 1
                        node.add_feature("curr_n", 2)
                        coa_roots.append(node)
                        last_coa_num = 2
                    else:
                        # This node adds one lineage to an existing group.
                        curr_n = coa_root.curr_n
                        coa_root.add_feature("curr_n", curr_n + 1)
                        last_coa_num = curr_n + 1
                    tcnt = tcnt + 1
                elif node.id == threshold_node.id:
                    reach_t = True
                    tcnt = 0
                    wt = waiting_time(length=times,
                                      num_coas=0,
                                      num_lines=curr_spe)
                    num_spe = curr_spe
                    curr_spe = curr_spe - 1
                    curr_num_coa = 2
                    node.add_feature("curr_n", 2)
                    coa_roots.append(node)
                else:
                    wt = waiting_time(length=times,
                                      num_coas=0,
                                      num_lines=curr_spe)
                    curr_spe = curr_spe + 1
                # Intervals of length <= 1e-8 are not recorded.
                if times > 0.00000001:
                    wt_list.append(wt)

        for wt in wt_list:
            wt.count_num_lines()

        self.species_list = []
        all_coa_leaves = []
        self.coa_roots = coa_roots
        for coa_r in coa_roots:
            leaves = coa_r.get_leaves()
            all_coa_leaves.extend(leaves)
            self.species_list.append(leaves)

        # Every leaf outside a coalescent group is its own species. A set
        # avoids the O(n^2) list membership test.
        coa_leaf_set = set(all_coa_leaves)
        for leaf in self.tree.get_leaves():
            if leaf not in coa_leaf_set:
                self.species_list.append([leaf])

        return wt_list, num_spe

    def show(self, wt_list):
        cnt = 1
        for wt in wt_list:
            print(("Waitting interval " + repr(cnt)))
            print(wt)
            cnt = cnt + 1

    @staticmethod
    def _line_style(fgcolor, vt_line_color, hz_line_color):
        """Return a NodeStyle with solid 2px lines, the given colours and no node marker."""
        style = NodeStyle()
        style["fgcolor"] = fgcolor
        style["vt_line_color"] = vt_line_color
        style["hz_line_color"] = hz_line_color
        style["vt_line_width"] = 2
        style["hz_line_width"] = 2
        style["vt_line_type"] = 0  # 0 solid, 1 dashed, 2 dotted
        style["hz_line_type"] = 0
        style["size"] = 0
        return style

    def get_species(self):
        """Return ``([all leaf names], [[names of species k] ...])`` and colour the tree.

        Every node is drawn blue. Coalescent-group roots get a red vertical
        line, and their descendants are drawn fully red.
        """
        sp_list = [[taxa.name for taxa in sp] for sp in self.species_list]
        all_taxa_name = [leaf.name for leaf in self.tree.get_leaves()]

        style0 = self._line_style("#000000", "#0000aa", "#0000aa")
        for node in self.tree.get_descendants():
            node.set_style(style0)
            node.img_style["size"] = 0
        self.tree.set_style(style0)
        self.tree.img_style["size"] = 0

        style1 = self._line_style("#000000", "#ff0000", "#0000aa")
        style2 = self._line_style("#0f0f0f", "#ff0000", "#ff0000")
        for node in self.coa_roots:
            node.set_style(style1)
            node.img_style["size"] = 0
            for des in node.get_descendants():
                des.set_style(style2)
                des.img_style["size"] = 0

        return [all_taxa_name], sp_list

    def print_species(self):
        """Write the species partition to ``<PATH>/GMYC/GMYC_MOTU.txt``.

        Each species is written as a ``Species <k>`` line followed by a
        tab-indented line of its taxa names separated by ", ".
        """
        with open(os.path.join(self.PATH, "GMYC/GMYC_MOTU.txt"), "w+") as sp_out:
            for cnt, sp in enumerate(self.species_list, start=1):
                sp_out.write("Species " + repr(cnt) + "\n")
                # join() rather than slicing one char off "a, b, ", which
                # left a trailing comma behind.
                sp_out.write("\t" + ", ".join(taxa.name for taxa in sp) + "\n")

    def output_species(self, taxa_order=None):
        """taxa_order is a list of taxa names, the paritions will be output as the same order.

        Returns ``(taxa_order, partition)``, where ``partition[i]`` is the
        1-based species index of ``taxa_order[i]``.  Returns ``(None, None)``
        (after printing an error) on a taxa count mismatch.
        """
        if taxa_order is None or len(taxa_order) == 0:
            taxa_order = self.tree.get_leaf_names()

        num_taxa = sum(len(sp) for sp in self.species_list)
        if not len(taxa_order) == num_taxa:
            print("error error, taxa_order != num_taxa!")
            return None, None
        partion = [-1] * num_taxa
        for cnt, sp in enumerate(self.species_list, start=1):
            for taxa in sp:
                partion[taxa_order.index(taxa.name)] = cnt
        return taxa_order, partion

    def num_lineages(self, wt_list):
        """Plot lineage count against time, save it as ``Time_Lines`` in the CWD and show it."""
        nl_list = []
        times = []
        last_time = 0.0
        for wt in wt_list:
            nl_list.append(wt.get_num_branches())
            times.append(last_time)
            last_time = wt.length + last_time

        plt.plot(times, nl_list)
        plt.ylabel("Number of lineages")
        plt.xlabel("Time")
        plt.savefig("Time_Lines")
        plt.show()
    if (not os.path.exists(arguments.tree_file2)):
        print("Tree file 2 not found.")
        sys.exit(1)

    # Read file, check it is in the correct format.
    try:
        print("Reading Trees...")
        tree1 = Tree(arguments.tree_file1)
        tree2 = Tree(arguments.tree_file2)
        print("Setting root on midpoint...")
        tree1_outgroup = tree1.get_midpoint_outgroup()
        tree1.set_outgroup(tree1_outgroup)
        tree2_outgroup = tree2.get_midpoint_outgroup()
        tree2.set_outgroup(tree2_outgroup)
        print("Collapsing nodes with branch distance = 0...")
        for node_tree1 in tree1.get_descendants():
            #print(node2.dist)
            #if not node2.is_leaf() and round(node2.dist,4) <= 1:
            if not node_tree1.is_leaf(
            ) and node_tree1.dist <= 1.00000050002909e-06:
                node_tree1.delete()
        for node_tree2 in tree2.get_descendants():
            #print(node2.dist)
            #if not node2.is_leaf() and round(node2.dist,4) <= 1:
            if not node_tree2.is_leaf(
            ) and node_tree2.dist <= 1.00000050002909e-06:
                node_tree2.delete()
        print("Trees read successfully.")
    except:
        print("Trees couldn't be loaded.")
        raise
    idToDescendants = dict()
    # Now we want to get the calibrations according to the options that have been user-input.

    t_begin = Tree()

    # Balanced or not?
    if ('y' in balanced):
        # Getting calibrations from both sides of the root
        t_begin = t
    else:
        # Getting calibrations only from one side
        choices = [0,1]
        choice  = random.choice(choices)
        print("Choosing calibrations from subtree: ", choice)
        t_begin = t.get_children()[choice]
    print("Number of nodes in sampled subtree: ", len(t_begin.get_descendants()))

    id2Height = getInternalNodeHeights( t_begin )
    nodeId2LeafListRef, leafList2NodeIdRef, idToDescendants = getNameToLeavesAndIdToDescendantIdsLink( t_begin )
    # Let's order the nodes according to their heights:
    d_ascending = OrderedDict(sorted(id2Height.items(), key=lambda kv: kv[1]))

    # Removing the root node
    d_ascending.popitem()

    calibrated_nodes_red = list()
    calibrated_nodes_blue = list()

    if ('y' in old):
        print("\tFavouring ancient calibrations\n")
        weights = biasedWeights(d_ascending)
                 #'beng|oriy', 
                 #'awad|bhoj|mait', 
                 'vlax|doma|doma', 
                 #'marw|dhun', 
                 #'mewa|bagr', 
                 'MP', 
                 #'jude|luri|bakh', 
                 #'gila|sang', 
                 #'awad|bhoj', 
                 #'vlax|doma', 
                 #'midd|dari', 
                 #'luri|bakh'
                 ]


# Collapse every internal node of the constraint tree whose name is not listed
# in nodes_to_keep. Leaves are always retained.
for node in ctree.get_descendants():
    if node.is_leaf() or node.name in nodes_to_keep:
        continue
    node.delete()

# Save the collapsed tree as newick (ete3 format 9).
ctree.write(format=9, outfile='constraint2.tre')


# Calibrations keyed by clade label (presumably the ancient clades, per the
# `_anc` suffix): label -> (taxon code, prior string).
# NOTE(review): the codes look like Glottolog identifiers, and every Unif(...)
# lists the larger bound first; confirm the consumer expects that order.
clade_calib_anc = {
    'PIA':('sans1269','Unif(3400,3000)'),
    'OP':('oldp1254','Unif(2500,2300)'),
    'MP':('pahl1241','Unif(1800,1400)')
}
Example #12
0
class exponential_mixture:
    """ML search PTP, to use: __init__(), search() and count_species().

    The branch lengths of the input tree are explained by two exponential
    distributions; ``count_species`` reports the best setting's ``rate2`` as
    the speciation rate and ``rate1`` as the coalescent rate. Every search
    heuristic builds candidate ``species_setting`` objects (sets of
    speciation nodes) and keeps the best one found so far in
    ``self.max_setting``, with its log-likelihood in ``self.max_logl``.
    """

    def __init__(self, tree, sp_rate = 0, fix_sp_rate = False, max_iters = 20000, min_br = 0.0001):
        """Load ``tree`` and fit the single-exponential null model.

        Args:
            tree: passed to ``Tree(tree, format = 1)``.
            sp_rate: speciation rate forwarded to every ``species_setting``.
            fix_sp_rate: forwarded to every ``species_setting``.
            max_iters: ``Brutal`` falls back to ``H0`` when
                ``comp_num_comb()`` exceeds this value.
            min_br: branches not longer than this are ignored by the null
                model and forwarded as ``minbr`` to ``species_setting``.
        """
        self.min_brl = min_br
        self.tree = Tree(tree, format = 1)
        # Resolve polytomies and give the root a zero-length branch.
        self.tree.resolve_polytomy(recursive=True)
        self.tree.dist = 0.0
        self.fix_spe_rate = fix_sp_rate
        self.fix_spe = sp_rate
        self.max_logl = float("-inf")
        self.max_setting = None
        self.null_logl = 0.0
        self.null_model()
        self.species_list = None
        self.counter = 0
        self.setting_set = set()
        self.max_num_search = max_iters


    def null_model(self):
        """Fit one exponential to all branches longer than ``min_brl``.

        Stores the summed log-likelihood in ``self.null_logl`` and returns
        the fitted rate.
        """
        coa_br = [node.dist for node in self.tree.get_descendants()
                  if node.dist > self.min_brl]
        e1 = exp_distribution(coa_br)
        self.null_logl = e1.sum_log_l()
        return e1.rate


    def __compare_node(self, node):
        """Sort key: the node's branch length."""
        return node.dist


    def re_rooting(self):
        """Re-root the tree on the descendant with the longest branch."""
        node_list = self.tree.get_descendants()
        # sort + reverse (not reverse=True) so ties resolve to the last node
        # in get_descendants() order, as every heuristic below relies on.
        node_list.sort(key=self.__compare_node)
        node_list.reverse()
        rootnode = node_list[0]
        self.tree.set_outgroup(rootnode)
        self.tree.dist = 0.0


    def comp_num_comb(self):
        """Count candidate speciation settings bottom-up.

        A leaf counts 1; an internal node counts the product of its
        children's counts plus 1. Each count is stored as feature ``cnt``;
        the root's count is returned.
        """
        for node in self.tree.traverse(strategy='postorder'):
            if node.is_leaf():
                node.add_feature("cnt", 1.0)
            else:
                acum = 1.0
                for child in node.get_children():
                    acum = acum * child.cnt
                acum = acum + 1.0
                node.add_feature("cnt", acum)
        return self.tree.cnt


    def next(self, sp_setting):
        """Recursively enumerate settings reachable from ``sp_setting``.

        Records ``sp_setting`` as visited and updates the best setting. Then,
        for each non-leaf active node, it builds the setting obtained by adding
        that node's children and recurses if that setting has not been seen.
        """
        self.setting_set.add(frozenset(sp_setting.spe_nodes))
        logl = sp_setting.get_log_l()
        if logl > self.max_logl:
            self.max_logl = logl
            self.max_setting = sp_setting
        for node in sp_setting.active_nodes:
            if node.is_leaf():
                continue
            sp_nodes = list(node.get_children())
            sp_nodes.extend(sp_setting.spe_nodes)
            new_sp_setting = species_setting(spe_nodes = sp_nodes, root = sp_setting.root, sp_rate = sp_setting.spe_rate, fix_sp_rate = sp_setting.fix_spe_rate, minbr = self.min_brl)
            if frozenset(sp_nodes) not in self.setting_set:
                self.next(new_sp_setting)


    def H0(self, reroot = True):
        """Run H1, then H2 and H3 on the same rooting; keep the overall best."""
        self.H1(reroot)
        self.H2(reroot = False)
        self.H3(reroot = False)


    def H1(self, reroot = True):
        """Heuristic 1: add speciation nodes in decreasing branch-length order.

        Starting from the setting {root, root's children}, each node not yet
        in the setting is added together with the children of every ancestor
        up to the first ancestor already in the setting (or the root). Each
        new setting replaces the previous one; the best is kept.
        """
        if reroot:
            self.re_rooting()

        # Longest branch first.
        sorted_node_list = self.tree.get_descendants()
        sorted_node_list.sort(key=self.__compare_node)
        sorted_node_list.reverse()

        first_node_list = [self.tree]
        first_node_list.extend(self.tree.get_children())
        first_setting = species_setting(spe_nodes = first_node_list, root = self.tree, sp_rate = self.fix_spe, fix_sp_rate = self.fix_spe_rate, minbr = self.min_brl)
        last_setting = first_setting
        max_logl = last_setting.get_log_l()
        max_setting = last_setting

        for node in sorted_node_list:
            if node in last_setting.spe_nodes:
                # Node already is a speciation node, do nothing.
                continue
            curr_sp_nodes = list(last_setting.spe_nodes)

            chosen_branching_node = node.up  # the father of this new node
            for nod in chosen_branching_node.get_children():
                if nod not in curr_sp_nodes:
                    curr_sp_nodes.append(nod)
            if chosen_branching_node not in last_setting.spe_nodes:
                # Walk up, adding the children of each ancestor, until an
                # ancestor that already is a speciation node is reached.
                while not chosen_branching_node.is_root():
                    chosen_branching_node = chosen_branching_node.up
                    for nod in chosen_branching_node.get_children():
                        if nod not in curr_sp_nodes:
                            curr_sp_nodes.append(nod)
                    if chosen_branching_node in last_setting.spe_nodes:
                        break
            new_setting = species_setting(spe_nodes = curr_sp_nodes, root = self.tree, sp_rate = self.fix_spe, fix_sp_rate = self.fix_spe_rate, minbr = self.min_brl)
            new_logl = new_setting.get_log_l()
            if new_logl > max_logl:
                max_logl = new_logl
                max_setting = new_setting
            last_setting = new_setting

        if max_logl > self.max_logl:
            self.max_logl = max_logl
            self.max_setting = max_setting


    def H2(self, reroot = True):
        """Greedy: repeatedly take the best single split of an active node.

        In each round every non-leaf active node is tried (its children are
        added to the setting). The search moves to the best resulting setting,
        even if that setting is worse than the best seen so far. It stops when
        no non-leaf active node remains.
        """
        if reroot:
            self.re_rooting()

        first_node_list = [self.tree]
        first_node_list.extend(self.tree.get_children())
        first_setting = species_setting(spe_nodes = first_node_list, root = self.tree, sp_rate = self.fix_spe, fix_sp_rate = self.fix_spe_rate, minbr = self.min_brl)
        last_setting = first_setting
        max_logl = last_setting.get_log_l()
        max_setting = last_setting
        contin_flag = True

        while contin_flag:
            curr_max_logl = float("-inf")
            curr_max_setting = None
            contin_flag = False
            for node in last_setting.active_nodes:
                if node.is_leaf():
                    continue
                contin_flag = True
                sp_nodes = list(node.get_children())
                sp_nodes.extend(last_setting.spe_nodes)
                new_sp_setting = species_setting(spe_nodes = sp_nodes, root = self.tree, sp_rate = self.fix_spe, fix_sp_rate = self.fix_spe_rate, minbr = self.min_brl)
                logl = new_sp_setting.get_log_l()
                if logl > curr_max_logl:
                    curr_max_logl = logl
                    curr_max_setting = new_sp_setting

            if curr_max_setting is None:
                # No candidate at all, or none with a finite log-likelihood:
                # previously this left last_setting = None and crashed on the
                # next round.
                break

            if curr_max_logl > max_logl:
                max_setting = curr_max_setting
                max_logl = curr_max_logl

            last_setting = curr_max_setting

        if max_logl > self.max_logl:
            self.max_logl = max_logl
            self.max_setting = max_setting


    def H3(self, reroot = True):
        """Greedy search guided by a two-exponential split of branch lengths.

        Branch lengths, longest first, are split at the index that maximizes
        the summed log-likelihood of two exponential fits. The nodes before
        the split become ``target_nodes``. Greedy expansion then only splits
        active nodes having a child in ``target_nodes``; when none does, one
        unrestricted greedy step is taken instead. The search stops after
        ``len(target_nodes)`` guided steps or when no active internal node
        remains.
        """
        if reroot:
            self.re_rooting()
        sorted_node_list = self.tree.get_descendants()
        sorted_node_list.sort(key=self.__compare_node)
        sorted_node_list.reverse()
        sorted_br = [node.dist for node in sorted_node_list]
        maxlogl = float("-inf")
        maxidx = -1
        for i in range(1, len(sorted_node_list)):
            e1 = exp_distribution(sorted_br[0:i])
            e2 = exp_distribution(sorted_br[i:])
            logl = e1.sum_log_l() + e2.sum_log_l()
            if logl > maxlogl:
                maxidx = i
                maxlogl = logl

        target_nodes = sorted_node_list[0:maxidx]

        first_node_list = [self.tree]
        first_node_list.extend(self.tree.get_children())
        first_setting = species_setting(spe_nodes = first_node_list, root = self.tree, sp_rate = self.fix_spe, fix_sp_rate = self.fix_spe_rate, minbr = self.min_brl)
        last_setting = first_setting
        max_logl = last_setting.get_log_l()
        max_setting = last_setting
        contin_flag = True
        target_node_cnt = 0
        while contin_flag:
            curr_max_logl = float("-inf")
            curr_max_setting = None
            contin_flag = False
            unchanged_flag = True
            # Guided step: only splits that expose a target node.
            for node in last_setting.active_nodes:
                if node.is_leaf():
                    continue
                contin_flag = True
                childs = node.get_children()
                if any(child in target_nodes for child in childs):
                    unchanged_flag = False
                    sp_nodes = list(childs)
                    sp_nodes.extend(last_setting.spe_nodes)
                    new_sp_setting = species_setting(spe_nodes = sp_nodes, root = self.tree, sp_rate = self.fix_spe, fix_sp_rate = self.fix_spe_rate, minbr = self.min_brl)
                    logl = new_sp_setting.get_log_l()
                    if logl > curr_max_logl:
                        curr_max_logl = logl
                        curr_max_setting = new_sp_setting
            if not unchanged_flag:
                target_node_cnt = target_node_cnt + 1
                if curr_max_logl > max_logl:
                    max_setting = curr_max_setting
                    max_logl = curr_max_logl
                last_setting = curr_max_setting

            if len(target_nodes) == target_node_cnt:
                contin_flag = False
            if contin_flag and unchanged_flag and last_setting is not None:
                # Fallback: no split touched a target node, take an
                # unrestricted greedy step instead.
                for node in last_setting.active_nodes:
                    if node.is_leaf():
                        continue
                    sp_nodes = list(node.get_children())
                    sp_nodes.extend(last_setting.spe_nodes)
                    new_sp_setting = species_setting(spe_nodes = sp_nodes, root = self.tree, sp_rate = self.fix_spe, fix_sp_rate = self.fix_spe_rate, minbr = self.min_brl)
                    logl = new_sp_setting.get_log_l()
                    if logl > curr_max_logl:
                        curr_max_logl = logl
                        curr_max_setting = new_sp_setting
                if curr_max_logl > max_logl:
                    max_setting = curr_max_setting
                    max_logl = curr_max_logl
                last_setting = curr_max_setting

        if max_logl > self.max_logl:
            self.max_logl = max_logl
            self.max_setting = max_setting


    def Brutal(self, reroot = False):
        """Exhaustive search via ``next``; falls back to H0 if too large.

        ``comp_num_comb()`` is compared against ``self.max_num_search``.
        """
        if reroot:
            self.re_rooting()
        num_s = self.comp_num_comb()
        if num_s > self.max_num_search:
            print("Too many search iterations: " + repr(num_s) + ", using H0 instead!!!")
            self.H0(reroot = False)
        else:
            first_node_list = [self.tree]
            first_node_list.extend(self.tree.get_children())
            first_setting = species_setting(spe_nodes = first_node_list, root = self.tree, sp_rate = self.fix_spe, fix_sp_rate = self.fix_spe_rate, minbr = self.min_brl)
            self.next(first_setting)


    def search(self, strategy = "H1", reroot = False):
        """Run one search strategy: "H1", "H2", "H3", "Brutal"; anything else runs H0."""
        if strategy == "H1":
            self.H1(reroot)
        elif strategy == "H2":
            self.H2(reroot)
        elif strategy == "H3":
            self.H3(reroot)
        elif strategy == "Brutal":
            self.Brutal(reroot)
        else:
            self.H0(reroot)


    def count_species(self, print_log = True, pv = 0.001):
        """Test the best setting against the null model and count species.

        A likelihood-ratio test ``lh_ratio_test(null_logl, max_logl, 1)``
        (the 1 presumably being the degrees of freedom) gives a p-value. If it
        is below ``pv``, the species count and ``species_list`` come from
        ``max_setting``. Otherwise all leaves form a single species and 1 is
        returned.
        """
        lhr = lh_ratio_test(self.null_logl, self.max_logl, 1)
        pvalue = lhr.get_p_value()
        if print_log:
            print("Speciation rate: " + "{0:.3f}".format(self.max_setting.rate2))
            print("Coalescent rate: " + "{0:.3f}".format(self.max_setting.rate1))
            print("Null logl: " + "{0:.3f}".format(self.null_logl))
            print("MAX logl: " + "{0:.3f}".format(self.max_logl))
            print("P-value: " + "{0:.3f}".format(pvalue))
            spefit, speaw = self.max_setting.e2.ks_statistic()
            coafit, coaaw = self.max_setting.e1.ks_statistic()
            print("Kolmogorov-Smirnov test for model fitting:")
            print("Speciation: " + "Dtest = {0:.3f}".format(spefit) + " " + speaw)
            print("Coalescent: " + "Dtest = {0:.3f}".format(coafit) + " " + coaaw)
        if pvalue < pv:
            num_sp, self.species_list = self.max_setting.count_species()
            return num_sp
        else:
            self.species_list = [self.tree.get_leaf_names()]
            return 1


    def whitening_search(self, strategy = "H1", reroot = False, pv = 0.001):
        """Search, prune to ``max_setting.whiten_species()``, then search again.

        ``pv`` is kept for backward compatibility only. ``search`` takes no
        p-value; apply the threshold with ``count_species(pv=...)``.
        """
        # search() accepts only (strategy, reroot); passing pv raised TypeError.
        self.search(strategy, reroot)
        _num_sp, self.species_list = self.max_setting.count_species()
        spekeep = self.max_setting.whiten_species()
        self.tree.prune(spekeep)
        # Reset all search state and refit the null model on the pruned tree.
        self.max_logl = float("-inf")
        self.max_setting = None
        self.null_logl = 0.0
        self.null_model()
        self.species_list = None
        self.counter = 0
        self.setting_set = set()
        self.search(strategy, reroot)


    def print_species(self):
        """Print each species in ``species_list`` with its leaf names."""
        for cnt, sp in enumerate(self.species_list, 1):
            print("Species " + repr(cnt) + ":")
            for leaf in sp:
                print("          " + leaf)


    def output_species(self, taxa_order = None):
        """Return ``(taxa_order, partition)`` with 1-based species indices.

        taxa_order is a list of taxa names; the partition is output in the
        same order. When it is omitted or empty, the tree's leaf order is
        used. If the number of taxa in ``species_list`` differs from
        ``len(taxa_order)``, an error is printed and ``(None, None)`` is
        returned.
        """
        # None instead of a shared mutable [] default; an explicit [] still
        # means "use the tree's leaf order".
        if taxa_order is None or len(taxa_order) == 0:
            taxa_order = self.tree.get_leaf_names()

        num_taxa = sum(1 for sp in self.species_list for _leaf in sp)
        if len(taxa_order) != num_taxa:
            print("error error, taxa_order != num_taxa!")
            return None, None
        partion = [-1] * num_taxa
        for cnt, sp in enumerate(self.species_list, 1):
            for leaf in sp:
                partion[taxa_order.index(leaf)] = cnt
        return taxa_order, partion
Example #13
0
print(t)
#          /-A
#---------|
#         |          /-B
#          \--------|
#                   |          /-C
#                    \--------|
#                              \-D
# Counts leaves within the tree
nleaves = 0
for leaf in t.get_leaves():
    nleaves += 1
print("This tree has", nleaves, "terminal nodes")
# But, like this is much simpler :)
nleaves = len(t)
print("This tree has", nleaves, "terminal nodes [proper way: len(tree) ]")
# Counts internal (non-leaf) descendant nodes within the tree
ninternal = 0
for node in t.get_descendants():
    if not node.is_leaf():
        ninternal += 1
print("This tree has", ninternal, "internal nodes")
# Counts nodes whose branch length (dist) is higher than 0.3
nnodes = 0
for node in t.get_descendants():
    if node.dist > 0.3:
        nnodes += 1
# or, translated into a more pythonic form
nnodes = len([n for n in t.get_descendants() if n.dist > 0.3])
print("This tree has", nnodes, "nodes with a branch length > 0.3")
Example #14
0
print(t)
#          /-A
#---------|
#         |          /-B
#          \--------|
#                   |          /-C
#                    \--------|
#                              \-D
# Counts leaves within the tree
nleaves = 0
for leaf in t.get_leaves():
    nleaves += 1
print("This tree has", nleaves, "terminal nodes")
# But, like this is much simpler :)
nleaves = len(t)
print("This tree has", nleaves, "terminal nodes [proper way: len(tree) ]")
# Counts internal (non-leaf) descendant nodes within the tree
ninternal = 0
for node in t.get_descendants():
    if not node.is_leaf():
        ninternal += 1
print("This tree has", ninternal, "internal nodes")
# Counts nodes whose branch length (dist) is higher than 0.3
nnodes = 0
for node in t.get_descendants():
    if node.dist > 0.3:
        nnodes += 1
# or, translated into a more pythonic form
nnodes = len([n for n in t.get_descendants() if n.dist > 0.3])
print("This tree has", nnodes, "nodes with a branch length > 0.3")
def _is_valid_inheritance_pair(rn1, rn2, min_leaves):
    """Return True if internal nodes rn1 and rn2 form a usable inheritance pair.

    Each node must subtend at least ``min_leaves`` leaves, neither may be a
    descendant of the other, and they must not be sisters.
    """
    if len(rn1.get_leaves()) < min_leaves or len(rn2.get_leaves()) < min_leaves:
        return False
    if rn2 in rn1.get_descendants() or rn1 in rn2.get_descendants():
        return False
    return rn2 not in rn1.get_sisters()


def _collect_inheritance_pair(t, rn1, rn2):
    """Return ``(r_nodes, r_tips, dist)`` for the chosen pair in tree ``t``.

    ``r_nodes`` / ``r_tips`` hold the internal-node / leaf names found by
    traversing rn1 and then rn2. ``dist`` is the topological distance between
    the common ancestors of each node's leaves.
    """
    r_tips = []
    r_nodes = []
    ancestors = []
    for subtree in (rn1, rn2):
        subtree_tips = []
        for node in subtree.traverse():
            if node.is_leaf():
                subtree_tips.append(node.name)
            else:
                r_nodes.append(node.name)
        r_tips.extend(subtree_tips)
        ancestors.append(t.get_common_ancestor(subtree_tips))
    dist = t.get_distance(ancestors[0], ancestors[1], topology_only=True)
    return r_nodes, r_tips, dist


def random_tree(trees):
    '''
    Randomly choose a tree and find two nodes for inheritance.

    A random line of the file ``trees`` is parsed with ``Tree(line, format=1)``
    and its non-root internal nodes are renamed n1, n2, ... Pairs of those
    nodes are sampled until one is found where neither node is a descendant
    or sister of the other. For the first 60 s both nodes must subtend at
    least 3 leaves; until 90 s, at least 2. After that, another tree is
    drawn.

    Returns:
        [tree, nodes, tips, r_nodes, r_tips, dist]:
        - tree: the result of ``topology_dist(...)``.
        - nodes / tips: internal-node / leaf names of the chosen tree.
        - r_nodes / r_tips: internal-node / leaf names under the two chosen
          nodes.
        - dist: the topological distance between the two chosen nodes.
    '''
    # Read the candidate trees once and close the file, instead of
    # re-opening (and leaking) a handle on every retry.
    with open(trees) as handle:
        tree_lines = handle.readlines()

    while True:
        # Randomly choose a tree; label its non-root internal nodes n1, n2, ...
        tree = choice(tree_lines)
        t = Tree(tree, format=1)
        tips = []
        nodes = []
        k = 1
        for node in t.traverse():
            if node.is_leaf():
                tips.append(node.name)
            elif not node.is_root():
                node.add_features(name='n' + str(k))
                nodes.append(node.name)
                k += 1
        nodes = list(filter(None, nodes))

        # Randomly choose two nodes for inheritance, relaxing the leaf-count
        # requirement after timeout1 and giving up on this tree after timeout2.
        timeout1 = time.time() + 60
        timeout2 = time.time() + 90
        while True:
            rn = sample(nodes, 2)
            rn1 = t.search_nodes(name=rn[0])[0]
            rn2 = t.search_nodes(name=rn[1])[0]
            now = time.time()
            if now <= timeout1:
                min_leaves = 3
            elif now <= timeout2:
                min_leaves = 2
            else:
                break
            if not _is_valid_inheritance_pair(rn1, rn2, min_leaves):
                continue
            r_nodes, r_tips, dist = _collect_inheritance_pair(t, rn1, rn2)
            tree = topology_dist(t, nodes, r_nodes, r_tips,
                                 branchProbabilityDist)
            return [tree, nodes, tips, r_nodes, r_tips, dist]
    cMtx = pd.DataFrame([[int(s in cWords[l]) for s in sounds] for l in cTaxa],
                        index=cTaxa,
                        columns=[c + ':' + s
                                 for s in sounds]).reindex(taxa,
                                                           fill_value='-')
    scMtx = pd.concat([scMtx, cMtx], axis=1)

nexCharOutput(scMtx, '../data/indoiranian.sc.nex', datatype="Standard")

# n, m = ccMtx.shape[1], scMtx.shape[1]
# cc_sc = pd.concat([ccMtx, scMtx], axis=1)

# nexCharOutput(cc_sc, 'indoiranian.nex', datatype='restriction')

# Internal clades of the constraint tree (neither root nor leaf), in
# get_descendants() order, labelled clade01, clade02, ...
nodes = np.array(
    [nd for nd in ctree.get_descendants() if not (nd.is_root() or nd.is_leaf())]
)

for i, nd in enumerate(nodes):
    nd.name = 'clade{:02d}'.format(i + 1)


def nname(x):
    """Return the label for node ``x``: leaf names in double quotes, others bare."""
    return '"' + x.name + '"' if x.is_leaf() else x.name


rev = ""
Example #17
0
from ete3 import Tree
tree = Tree('(A:1,(B:1,(C:1,D:1):0.5):0.5);')
# Prints the name of every leaf under the tree root
print("Leaf names:")
for leaf in tree.get_leaves():
    print(leaf.name)
# Label nodes as terminal or internal. If internal, also save the
# number of leaves that it contains.
print("Labeled tree:")
for node in tree.get_descendants():
    if node.is_leaf():
        node.add_features(ntype="terminal")
    else:
        node.add_features(ntype="internal", size=len(node))
# Gets the extended newick of the tree including new node features
print(tree.write(features=[]))
Example #18
0
class exponential_mixture:
    """ML search PTP, to use: __init__(), search() and count_species()"""
    def __init__(
        self,
        tree,
        sp_rate=0,
        fix_sp_rate=False,
        max_iters=20000,
        min_br=0.0001,
    ):
        """Parse ``tree``, resolve its polytomies and fit the null model.

        ``sp_rate`` / ``fix_sp_rate`` are stored and forwarded to the
        ``species_setting`` objects built by the heuristics. ``max_iters`` is
        stored as ``max_num_search``. ``min_br`` is the branch-length threshold
        used by ``null_model``.
        """
        # Search parameters.
        self.min_brl = min_br
        self.fix_spe_rate = fix_sp_rate
        self.fix_spe = sp_rate
        self.max_num_search = max_iters

        # Input tree, with polytomies resolved and no branch above the root.
        self.tree = Tree(tree, format=1)
        self.tree.resolve_polytomy(recursive=True)
        self.tree.dist = 0.0

        # Search state.
        self.max_logl = float("-inf")
        self.max_setting = None
        self.species_list = None
        self.counter = 0
        self.setting_set = set()

        # Null model; null_model() overwrites null_logl.
        self.null_logl = 0.0
        self.null_model()

    def null_model(self):
        """Fit one exponential to every branch longer than ``min_brl``.

        The fit's summed log-likelihood is stored in ``self.null_logl`` and
        its rate is returned.
        """
        threshold = self.min_brl
        kept_lengths = [
            n.dist for n in self.tree.get_descendants() if n.dist > threshold
        ]
        fit = exp_distribution(kept_lengths)
        self.null_logl = fit.sum_log_l()
        return fit.rate

    def __compare_node(self, node):
        return node.dist

    def re_rooting(self):
        node_list = self.tree.get_descendants()
        node_list.sort(key=self.__compare_node)
        node_list.reverse()
        rootnode = node_list[0]
        self.tree.set_outgroup(rootnode)
        self.tree.dist = 0.0

    def comp_num_comb(self):
        for node in self.tree.traverse(strategy="postorder"):
            if node.is_leaf():
                node.add_feature("cnt", 1.0)
            else:
                acum = 1.0
                for child in node.get_children():
                    acum = acum * child.cnt
                acum = acum + 1.0
                node.add_feature("cnt", acum)
        return self.tree.cnt

    def next(self, sp_setting):
        """Recursively enumerate settings reachable from ``sp_setting``.

        The setting is recorded as visited and the best setting is updated.
        Then, for every non-leaf active node, a new setting is built by adding
        that node's children, and the search recurses into it unless it was
        already visited.
        """
        self.setting_set.add(frozenset(sp_setting.spe_nodes))
        logl = sp_setting.get_log_l()
        if logl > self.max_logl:
            self.max_logl, self.max_setting = logl, sp_setting
        for node in sp_setting.active_nodes:
            if node.is_leaf():
                continue
            candidate = list(node.get_children()) + list(sp_setting.spe_nodes)
            child_setting = species_setting(
                spe_nodes=candidate,
                root=sp_setting.root,
                sp_rate=sp_setting.spe_rate,
                fix_sp_rate=sp_setting.fix_spe_rate,
                minbr=self.min_brl,
            )
            if frozenset(candidate) not in self.setting_set:
                self.next(child_setting)

    def H0(self, reroot=True):
        """Run H1, then H2 and run_h3 without re-rooting again."""
        self.H1(reroot)
        for heuristic in (self.H2, self.run_h3):
            heuristic(reroot=False)

    def H1(self, reroot=True):
        """Heuristic 1: add speciation nodes in decreasing branch-length order.

        Starting from the setting {root, root's children}, every node not yet
        in the current setting is added together with the children of each of
        its ancestors, up to the first ancestor already in the setting (or the
        root). Each new setting replaces the previous one. The best setting
        found updates ``self.max_logl`` / ``self.max_setting`` if it improves
        on them.
        """
        if reroot:
            self.re_rooting()

        # self.init_tree()
        # Longest branch first (sort + reverse, so ties keep reversed order).
        sorted_node_list = self.tree.get_descendants()
        sorted_node_list.sort(key=self.__compare_node)
        sorted_node_list.reverse()

        # Initial setting: the root and its direct children.
        first_node_list = []
        first_node_list.append(self.tree)
        first_childs = self.tree.get_children()
        for child in first_childs:
            first_node_list.append(child)
        first_setting = species_setting(
            spe_nodes=first_node_list,
            root=self.tree,
            sp_rate=self.fix_spe,
            fix_sp_rate=self.fix_spe_rate,
            minbr=self.min_brl,
        )
        last_setting = first_setting
        max_logl = last_setting.get_log_l()
        max_setting = last_setting

        for node in sorted_node_list:
            if node not in last_setting.spe_nodes:
                # Copy the current speciation nodes and extend the copy.
                curr_sp_nodes = []
                for nod in last_setting.spe_nodes:
                    curr_sp_nodes.append(nod)

                chosen_branching_node = (node.up
                                         )  # find the father of this new node
                if chosen_branching_node in last_setting.spe_nodes:
                    for nod in chosen_branching_node.get_children():
                        if nod not in curr_sp_nodes:
                            curr_sp_nodes.append(nod)
                else:
                    for nod in chosen_branching_node.get_children():
                        if nod not in curr_sp_nodes:
                            curr_sp_nodes.append(nod)
                    # Walk up, adding each ancestor's children, until an
                    # ancestor that already is a speciation node is reached.
                    while not chosen_branching_node.is_root():
                        chosen_branching_node = chosen_branching_node.up
                        for nod in chosen_branching_node.get_children():
                            if nod not in curr_sp_nodes:
                                curr_sp_nodes.append(nod)
                        if chosen_branching_node in last_setting.spe_nodes:
                            break
                new_setting = species_setting(
                    spe_nodes=curr_sp_nodes,
                    root=self.tree,
                    sp_rate=self.fix_spe,
                    fix_sp_rate=self.fix_spe_rate,
                    minbr=self.min_brl,
                )
                new_logl = new_setting.get_log_l()
                if new_logl > max_logl:
                    max_logl = new_logl
                    max_setting = new_setting
                # Settings are cumulative: the next node extends this one.
                last_setting = new_setting

            else:
                """node already is a speciation node, do nothing"""
                pass

        if max_logl > self.max_logl:
            self.max_logl = max_logl
            self.max_setting = max_setting

    def H2(self, reroot=True):
        """Greedy: repeatedly take the best single split of an active node.

        Each round tries every non-leaf active node of the current setting
        (adding its children to the speciation nodes). The search moves to the
        best resulting setting, even if it is worse than the best seen so far,
        and stops when no non-leaf active node remains. The overall best
        updates ``self.max_logl`` / ``self.max_setting`` if it improves on
        them.
        """
        if reroot:
            self.re_rooting()

        # self.init_tree()
        # NOTE(review): sorted_node_list is computed but never used below.
        sorted_node_list = self.tree.get_descendants()
        sorted_node_list.sort(key=self.__compare_node)
        sorted_node_list.reverse()

        # Initial setting: the root and its direct children.
        first_node_list = []
        first_node_list.append(self.tree)
        first_childs = self.tree.get_children()
        for child in first_childs:
            first_node_list.append(child)
        first_setting = species_setting(
            spe_nodes=first_node_list,
            root=self.tree,
            sp_rate=self.fix_spe,
            fix_sp_rate=self.fix_spe_rate,
            minbr=self.min_brl,
        )
        last_setting = first_setting
        max_logl = last_setting.get_log_l()
        max_setting = last_setting
        contin_flag = True

        while contin_flag:
            curr_max_logl = float("-inf")
            curr_max_setting = None
            # Stays False (ending the loop) if every active node is a leaf.
            contin_flag = False
            for node in last_setting.active_nodes:
                if node.is_leaf():
                    pass
                else:
                    contin_flag = True
                    childs = node.get_children()
                    sp_nodes = []
                    for child in childs:
                        sp_nodes.append(child)
                    for nod in last_setting.spe_nodes:
                        sp_nodes.append(nod)
                    new_sp_setting = species_setting(
                        spe_nodes=sp_nodes,
                        root=self.tree,
                        sp_rate=self.fix_spe,
                        fix_sp_rate=self.fix_spe_rate,
                        minbr=self.min_brl,
                    )
                    logl = new_sp_setting.get_log_l()
                    if logl > curr_max_logl:
                        curr_max_logl = logl
                        curr_max_setting = new_sp_setting

            if curr_max_logl > max_logl:
                max_setting = curr_max_setting
                max_logl = curr_max_logl

            # NOTE(review): if every candidate's logl is -inf/NaN,
            # curr_max_setting stays None and the next round fails on
            # None.active_nodes — confirm this cannot happen.
            last_setting = curr_max_setting

        if max_logl > self.max_logl:
            self.max_logl = max_logl
            self.max_setting = max_setting

    def run_h3(self, reroot=True):
        """Greedy species-delimitation search guided by a branch-length split.

        This is the "H3" strategy dispatched from ``search``.

        1. Optionally re-root the tree (``self.re_rooting()``).
        2. Sort every non-root node using ``self.__compare_node`` as the sort
           key, then reverse the order. For every split point ``i`` that
           leaves both halves non-empty, fit one ``exp_distribution`` to the
           first ``i`` branch lengths (``node.dist``) and another to the
           rest. The split with the highest summed log-likelihood defines
           ``target_nodes``: the nodes before the split.
           NOTE(review): presumably the key is branch length, so the target
           nodes are the longest branches -- confirm against
           ``__compare_node``.
        3. Start from a ``species_setting`` whose speciation nodes are the
           root and its direct children. Then repeatedly add the children of
           one non-leaf active node, keeping the highest-likelihood candidate
           of each round. A round only considers nodes that have a child in
           ``target_nodes``. If no node qualifies, every non-leaf active node
           is tried instead.

        ``self.max_logl`` / ``self.max_setting`` are replaced only when the
        best setting found here has a higher log-likelihood.

        :param reroot: if True, re-root the tree before searching.
        """
        if reroot:
            self.re_rooting()
        sorted_node_list = self.tree.get_descendants()
        # NOTE(review): ``key=`` calls its function with ONE argument; confirm
        # ``__compare_node`` is a key function, not a two-argument cmp.
        sorted_node_list.sort(key=self.__compare_node)
        sorted_node_list.reverse()
        sorted_br = []
        for node in sorted_node_list:
            sorted_br.append(node.dist)
        maxlogl = float("-inf")
        maxidx = -1
        # Try every split i in 1..n-1 so both l1 and l2 are non-empty.
        for i in range(len(sorted_node_list))[1:]:
            l1 = sorted_br[0:i]
            l2 = sorted_br[i:]
            e1 = exp_distribution(l1)
            e2 = exp_distribution(l2)
            logl = e1.sum_log_l() + e2.sum_log_l()
            if logl > maxlogl:
                maxidx = i
                maxlogl = logl

        # NOTE(review): with fewer than two descendants no split is tried,
        # maxidx stays -1 and this slice drops the last node rather than
        # being empty.
        target_nodes = sorted_node_list[0:maxidx]

        # Initial setting: the root plus its direct children.
        first_node_list = []
        first_node_list.append(self.tree)
        first_childs = self.tree.get_children()
        for child in first_childs:
            first_node_list.append(child)
        first_setting = species_setting(
            spe_nodes=first_node_list,
            root=self.tree,
            sp_rate=self.fix_spe,
            fix_sp_rate=self.fix_spe_rate,
            minbr=self.min_brl,
        )
        last_setting = first_setting
        max_logl = last_setting.get_log_l()
        max_setting = last_setting
        contin_flag = True
        # NOTE(review): this counts rounds in which a target-driven expansion
        # happened, not distinct target nodes reached.
        target_node_cnt = 0
        while contin_flag:
            curr_max_logl = float("-inf")
            curr_max_setting = None
            # Stays False (ending the loop) if no active node is internal.
            contin_flag = False
            # Stays True if no candidate touched a target node this round.
            unchanged_flag = True
            for node in last_setting.active_nodes:
                if node.is_leaf():
                    pass
                else:
                    contin_flag = True
                    childs = node.get_children()
                    sp_nodes = []
                    flag = False
                    # Only expand nodes with at least one child in target_nodes
                    # (list membership: linear scan per child).
                    for child in childs:
                        if child in target_nodes:
                            flag = True
                    # Disabled: target nodes are never consumed.
                    # target_nodes.remove(child)
                    if flag:
                        unchanged_flag = False
                        # Candidate = node's children + current speciation nodes.
                        for child in childs:
                            sp_nodes.append(child)
                        for nod in last_setting.spe_nodes:
                            sp_nodes.append(nod)
                        new_sp_setting = species_setting(
                            spe_nodes=sp_nodes,
                            root=self.tree,
                            sp_rate=self.fix_spe,
                            fix_sp_rate=self.fix_spe_rate,
                            minbr=self.min_brl,
                        )
                        logl = new_sp_setting.get_log_l()
                        if logl > curr_max_logl:
                            curr_max_logl = logl
                            curr_max_setting = new_sp_setting
            if not unchanged_flag:
                target_node_cnt = target_node_cnt + 1
                if curr_max_logl > max_logl:
                    max_setting = curr_max_setting
                    max_logl = curr_max_logl
                # NOTE(review): if every candidate's logl was -inf/NaN,
                # curr_max_setting is None here and the next round's
                # ``last_setting.active_nodes`` would raise AttributeError.
                last_setting = curr_max_setting

            if len(target_nodes) == target_node_cnt:
                contin_flag = False
            # Fallback: no target-driven expansion was possible this round,
            # so try expanding every internal active node instead.
            if contin_flag and unchanged_flag and last_setting != None:
                for node in last_setting.active_nodes:
                    if node.is_leaf():
                        pass
                    else:
                        childs = node.get_children()
                        sp_nodes = []
                        for child in childs:
                            sp_nodes.append(child)
                        for nod in last_setting.spe_nodes:
                            sp_nodes.append(nod)
                        new_sp_setting = species_setting(
                            spe_nodes=sp_nodes,
                            root=self.tree,
                            sp_rate=self.fix_spe,
                            fix_sp_rate=self.fix_spe_rate,
                            minbr=self.min_brl,
                        )
                        logl = new_sp_setting.get_log_l()
                        if logl > curr_max_logl:
                            curr_max_logl = logl
                            curr_max_setting = new_sp_setting
                if curr_max_logl > max_logl:
                    max_setting = curr_max_setting
                    max_logl = curr_max_logl
                last_setting = curr_max_setting

        # Publish only if this search beat the best result recorded so far.
        if max_logl > self.max_logl:
            self.max_logl = max_logl
            self.max_setting = max_setting

    def Brutal(self, reroot=False):
        """Exhaustive ("Brutal") search over species settings.

        If ``self.comp_num_comb()`` exceeds ``self.max_num_search`` a warning
        is printed and the H0 heuristic runs instead (without re-rooting
        again). Otherwise the recursion ``self.next`` starts from the setting
        whose speciation nodes are the root and its direct children.

        :param reroot: if True, call ``self.re_rooting()`` first.
        """
        if reroot:
            self.re_rooting()
        root = self.tree
        start_nodes = [root]
        start_nodes.extend(root.get_children())
        n_combinations = self.comp_num_comb()
        if n_combinations > self.max_num_search:
            print("Too many search iterations: {0!r}, using H0 instead!!!"
                  .format(n_combinations))
            self.H0(reroot=False)
            return
        start_setting = species_setting(
            spe_nodes=start_nodes,
            root=root,
            sp_rate=self.fix_spe,
            fix_sp_rate=self.fix_spe_rate,
            minbr=self.min_brl,
        )
        self.next(start_setting)

    def search(self, strategy="H1", reroot=False):
        """Run the search strategy named by ``strategy``.

        ``"H1"``, ``"H2"``, ``"H3"`` (``run_h3``) and ``"Brutal"`` select the
        matching method; any other value falls back to ``H0``. ``reroot`` is
        forwarded positionally to whichever method runs.
        """
        dispatch = (
            ("H1", "H1"),
            ("H2", "H2"),
            ("H3", "run_h3"),
            ("Brutal", "Brutal"),
        )
        for label, method_name in dispatch:
            if strategy == label:
                getattr(self, method_name)(reroot)
                return
        self.H0(reroot)

    def count_species(self, print_log=True, pv=0.001):
        """Return the number of species supported by the best setting.

        A likelihood-ratio test, ``lh_ratio_test(self.null_logl,
        self.max_logl, 1)``, yields a p-value. If the p-value is below ``pv``,
        the species of ``self.max_setting`` are accepted. Otherwise all leaves
        of the tree are grouped as a single species.

        :param print_log: if True, first print the rates of the best setting,
            both log-likelihoods, the p-value and the Kolmogorov-Smirnov fit
            statistics of ``e2`` (speciation) and ``e1`` (coalescent).
        :param pv: significance threshold for the likelihood-ratio test.
        :returns: the species count. ``self.species_list`` is set to the
            matching groups of leaf names.
        """
        p_value = lh_ratio_test(self.null_logl, self.max_logl, 1).get_p_value()
        if print_log:
            best = self.max_setting
            print("Speciation rate: {0:.3f}".format(best.rate2))
            print("Coalesecnt rate: {0:.3f}".format(best.rate1))
            print("Null logl: {0:.3f}".format(self.null_logl))
            print("MAX logl: {0:.3f}".format(self.max_logl))
            print("P-value: {0:.3f}".format(p_value))
            spe_d, spe_verdict = best.e2.ks_statistic()
            coa_d, coa_verdict = best.e1.ks_statistic()
            print("Kolmogorov-Smirnov test for model fitting:")
            print("Speciation: Dtest = {0:.3f} ".format(spe_d) + spe_verdict)
            print("Coalescent: Dtest = {0:.3f} ".format(coa_d) + coa_verdict)
        if p_value < pv:
            num_species, self.species_list = self.max_setting.count_species()
            return num_species
        # Not significant: treat the whole tree as one species.
        self.species_list = [self.tree.get_leaf_names()]
        return 1

    def whitening_search(self, strategy="H1", reroot=False, pv=0.001):
        """Search, prune the tree to the whitened species, then search again.

        Steps:

        1. Run ``search(strategy, reroot)``.
        2. Call ``count_species()`` and ``whiten_species()`` on the best
           setting.
        3. Prune ``self.tree`` to the nodes returned by ``whiten_species``.
        4. Reset ``max_logl``, ``max_setting``, ``null_logl``,
           ``species_list``, ``counter`` and ``setting_set``, and recompute
           the null model with ``null_model()``.
        5. Run the same search again.

        :param strategy: strategy name forwarded to ``search``.
        :param reroot: forwarded to ``search``.
        :param pv: kept for backward compatibility and currently unused,
            because ``search`` takes no p-value argument.
        """
        # ``search`` only accepts (strategy, reroot). Forwarding ``pv`` as a
        # third positional argument made this method raise TypeError.
        self.search(strategy, reroot)
        _, self.species_list = self.max_setting.count_species()
        kept = self.max_setting.whiten_species()
        self.tree.prune(kept)
        self.max_logl = float("-inf")
        self.max_setting = None
        self.null_logl = 0.0
        self.null_model()
        self.species_list = None
        self.counter = 0
        self.setting_set = set([])
        self.search(strategy, reroot)

    def print_species(self):
        """Print each species in ``self.species_list``, numbered from 1.

        Each species is printed as a "Species N:" header followed by one
        indented line per leaf name.
        """
        for number, members in enumerate(self.species_list, 1):
            print("Species " + repr(number) + ":")
            for name in members:
                print("          " + name)

    def output_species(self, taxa_order=None):
        """Return each taxon's species number, in a caller-chosen order.

        ``taxa_order`` is a list of taxa names; the partitions are output in
        that same order. If it is None or empty, the tree's leaf names
        (``self.tree.get_leaf_names()``) are used.

        :returns: ``(taxa_order, partition)``, where ``partition[i]`` is the
            1-based position in ``self.species_list`` of the species that
            contains ``taxa_order[i]``. If the number of names in
            ``taxa_order`` differs from the total number of species members,
            an error is printed and ``(None, None)`` is returned.
        :raises ValueError: if a species member is absent from ``taxa_order``.
        """
        # None replaces the former mutable ``[]`` default; an explicit empty
        # sequence still selects the tree's leaf order.
        if taxa_order is None or len(taxa_order) == 0:
            taxa_order = self.tree.get_leaf_names()

        num_taxa = sum(1 for sp in self.species_list for _ in sp)
        if not len(taxa_order) == num_taxa:
            print("error error, taxa_order != num_taxa!")
            return None, None

        # Name -> index map built once (O(n)) instead of calling
        # ``taxa_order.index`` per leaf (O(n^2)). setdefault keeps the FIRST
        # occurrence of a duplicated name, matching ``list.index``.
        position = {}
        for idx, name in enumerate(taxa_order):
            position.setdefault(name, idx)

        partition = [-1] * num_taxa
        for species_no, members in enumerate(self.species_list, 1):
            for leaf in members:
                try:
                    idx = position[leaf]
                except KeyError:
                    raise ValueError("{0!r} is not in taxa_order".format(leaf))
                partition[idx] = species_no
        return taxa_order, partition
Example #19
0
def parse_tree(tree_file,
               logging,
               ete_format=0,
               set_root=False,
               resolve_polytomy=True,
               ladderize=True,
               method='midpoint',
               outgroup=''):
    """
    Parses newick formatted tree into an ETE tree object

    Parameters
    ----------
    tree_file [str] : Path to newick formatted tree
    logging [logging obj] : logging object (needs info() and error())
    ete_format [int] : Newick format passed to ete3 Tree; refer to documentation http://etetoolkit.org/docs/latest/reference/reference_tree.html
    set_root [Bool] : Set root for the tree
    resolve_polytomy [Bool]: Force bifurcations of the tree on polytomies
    ladderize [Bool] : Sort the branches of a given tree (swapping children nodes) according to the size
    method [str]: Method to root tree, either midpoint or outgroup
    outgroup [str] : Name of taxon to root tree on

    Every node with an empty name is given a numeric string name that does
    not clash with any name already present in the tree.

    Returns
    -------

    ETE tree obj on success.
    An empty dict if tree_file does not exist or is empty, and None if
    outgroup rooting was requested without a valid outgroup (an error is
    logged in both cases). If the final tree is not bifurcating at the root,
    an error is logged but the tree is still returned.

    """

    logging.info("Attempting to parse tree file {}".format(tree_file))

    if not os.path.isfile(tree_file):
        logging.error(
            "Specified tree file {} is not found, please check that it exists".
            format(tree_file))
        return dict()
    if os.path.getsize(tree_file) == 0:
        logging.error(
            "Specified tree file {} is found but is empty".format(tree_file))
        return dict()

    # Load a tree structure from a newick file.
    t = Tree(tree_file, format=ete_format)

    logging.info("Read {} samples comprising {} nodes from tree {}".format(
        len(t.get_leaf_names()), len(t.get_descendants()), tree_file))

    # Need this otherwise groups derived from the tree are inaccurate
    if resolve_polytomy:
        logging.info(
            "Resolving polytomies through forced bifurcation {}".format(
                tree_file))
        t.resolve_polytomy()

    # Ladderize tree for aesthetics
    if ladderize:
        logging.info("Ladderizing tree {}".format(tree_file))
        t.ladderize()

    # Rooting tree based on user criteria
    if set_root:
        if method == 'midpoint':
            logging.info("Setting root based on midpoint rooting from ETE3")
            root = t.get_midpoint_outgroup()
            t.set_outgroup(root)
        elif method == 'outgroup':
            if outgroup == '':
                logging.error(
                    "User selected outgroup rooting but did not provide an outgroup, please refer to the documentation for setting an outgroup"
                )
                return None
            else:
                # if outgroup label not in the tree, return an error to the user
                if t.search_nodes(name=outgroup):
                    t.set_outgroup(outgroup)
                else:
                    logging.error(
                        "Specified outgroup not found in tree, please check the name and try again"
                    )
                    return None

    sample_list = t.get_leaf_names()

    # Label all unnamed nodes with numeric ids. Node names are strings, so
    # the candidate id must be compared as a string. The former int
    # comparison never matched, which let labels clash with numeric sample
    # names. A set makes each membership test O(1).
    used_names = {node.name for node in t.traverse() if node.name != ''}
    node_id = 0
    for node in t.traverse("preorder"):
        if node.name == '':
            while str(node_id) in used_names:
                node_id += 1
            node.name = str(node_id)
            used_names.add(node.name)
            node_id += 1

    tree_root = t.get_tree_root()
    # Count the leaves directly. The former ``t.children[0].get_tree_root()``
    # detour raised IndexError for a tree with no children.
    num_samples = len(sample_list)
    num_nodes = len(t.get_descendants())
    is_rooted = len(tree_root.children) == 2
    root_id = tree_root.name
    max_node, max_dist = t.get_farthest_leaf()
    logging.info("Read {} samples comprising {} nodes from tree {}:\n".format(
        num_samples, num_nodes, tree_file))
    logging.info("Most distant sample id {}".format(max_node.name))
    logging.info("Max Distance {}".format(max_dist))
    if is_rooted:
        # Report the first leaf met in level order, i.e. the leaf closest to
        # the root in number of edges.
        for node in t.traverse("levelorder"):
            if node.is_leaf():
                root_id = node.name
                break
        logging.info("Tree is rooted on sample {}".format(root_id))
    else:
        logging.error(
            "Error the tree is unrooted, you must either specify the root using an outgroup or midpoint rooting"
        )

    return t