def rename_model(target, model, accelerated_genomes):
    """Iteratively rename each ancestor of accelerated_genomes, walking down the tree to each ancestor.

    Generator. When more than one accelerated genome is present in the
    model's tree, every internal node from their common ancestor downwards
    is, in turn, given a name derived from its children. The model file is
    rewritten with that tree, and ``(node_name, model_path)`` is yielded.
    After the consumer resumes, the node's name is set back to ``'1'``.

    Args:
        target: object providing ``getGlobalTempDir()``; the rewritten model
            is written there (presumably a Toil job — TODO confirm).
        model: path to a model file whose last line is ``TREE: <newick>``.
        accelerated_genomes: genome names of interest; names that are not
            leaves of the model's tree are ignored.

    Yields:
        ``(name, path)`` tuples. ``path`` is always the same temp file and is
        overwritten on every step, so use it before advancing the generator.
        If none of the genomes are in the tree, advancing raises IndexError.
    """
    new_model = os.path.join(target.getGlobalTempDir(),
                             'region_specific_conserved_subtree.mod')
    # Read eagerly and close the handle; previously it leaked for the
    # whole lifetime of the generator.
    with open(model) as inf:
        lines = inf.readlines()
    t = Tree(lines[-1].split('TREE: ')[1], format=1)

    def _write_model():
        """Write the model header followed by the current state of ``t``."""
        with open(new_model, 'w') as outf:
            outf.writelines(lines[:-1])
            outf.write('TREE: ' + t.write(format=1) + '\n')

    # this model may not have all of the genomes, if they were not aligned in this region
    accelerated_genomes = list(
        set(t.get_leaf_names()) & set(accelerated_genomes))
    if len(accelerated_genomes) > 1:
        anc = t.get_common_ancestor(accelerated_genomes)
        # Descendants that are not leaves, i.e. the internal nodes below anc
        # (O(n) instead of scanning the leaf list for every node).
        internal_nodes = [x for x in anc.get_descendants() if not x.is_leaf()]
        for n in [anc] + internal_nodes:
            # Children named '1' are skipped when building the new name.
            oldest_name = [x.name for x in n.get_children() if x.name != '1']
            if len(oldest_name) == 1:
                n.name = oldest_name[0] + '_Anc'
            else:
                n.name = '_'.join(oldest_name)
            _write_model()
            yield n.name, new_model
            n.name = '1'
    else:  # only one accelerated genome here -- get common ancestor above will return root node
        _write_model()
        yield accelerated_genomes[0], new_model
# Example #2
0
def generate_subfams_treestrat(fin_seqs, dio_tmp, threads):
    """Split a gene family into two subfamilies using a midpoint-rooted tree.

    The sequences are aligned with ``run_mafft``. A tree is inferred from the
    alignment with ``run_iqtree`` using the LG model, and the leaves are
    divided at the tree's midpoint outgroup.

    Args:
        fin_seqs (str): Path to a FASTA file with the gene sequences.
        dio_tmp (str): Folder for temporary files (alignment and tree).
        threads (int): Number of threads to use.

    Returns:
        list: ``[ids_subfamily_one, ids_subfamily_two]``. The first holds the
        leaf names under the midpoint outgroup; the second holds the
        remaining leaf names.
    """
    alignment_path = f"{dio_tmp}/seqs.aln"
    tree_dir = f"{dio_tmp}/tree"

    run_mafft(fin_seqs, alignment_path, threads)
    run_iqtree(alignment_path, tree_dir, threads, ["-m", "LG"])
    tree = Tree(f"{tree_dir}/tree.treefile")

    # NOTE(review): a None outgroup is not handled here and would raise
    # AttributeError below; an earlier version had a commented-out guard.
    outgroup = tree.get_midpoint_outgroup()
    first_ids = outgroup.get_leaf_names()
    # Detaching the outgroup leaves only the second subfamily in `tree`.
    outgroup.detach()
    return [first_ids, tree.get_leaf_names()]
# Example #3
0
def run_rescale(prefix, tree, data, n_proc=5):
    """Re-estimate the branch lengths of a fixed topology with RAxML and average them.

    For every ``(phy, weights, asc, invariants)`` tuple in ``data``, the
    module-level ``raxml`` binary (seed ``rint``) optimises branch lengths on
    the topology in ``tree`` (``-f e``). It uses GTRGAMMA when ``asc`` is None
    and ASC_GTRGAMMA with ``--asc-corr stamatakis -q asc`` otherwise. The
    resulting tree is written to ``phy + '.subtree'``, and every branch
    length is recorded under the sorted leaf names below that branch.
    RAxML output files and the input files are then deleted.

    Finally, each branch of ``tree`` receives the mean recorded length for its
    split, looked up by either side of the bipartition (0.0 if never seen).
    Lengths whose product with the summed ``invariants`` values lies in
    (-0.5, 0.5) are set to 0.0.

    Args:
        prefix: RAxML run name (``-n``) and output file prefix.
        tree: path to the Newick topology (read with ete3 format 1).
        data: iterable of ``(phy, weights, asc, invariants)``; ``invariants``
            is a mapping whose values are summed.
        n_proc: number of RAxML threads (``-T``).

    Returns:
        The path ``'<prefix>.unrooted.nwk'`` of the written tree.
    """
    branches = {}
    cnt = 0
    for phy, weights, asc, invariants in data:
        cnt += sum(invariants.values())
        # Remove output left over from an earlier run with the same name.
        for fname in glob.glob('RAxML_*.{0}'.format(prefix)):
            os.unlink(fname)
        if asc is None:
            cmd = '{0} -m GTR{4} -n {1} -t {7} -f e -D -s {2} -a {3} -T {5} -p {6} --no-bfgs'.format(
                raxml, prefix, phy, weights, 'GAMMA', n_proc, rint, tree)
        else:
            cmd = '{0} -m ASC_GTR{5} -n {1} -t {8} -f e -D -s {2} -a {3} -T {6} -p {7} --asc-corr stamatakis --no-bfgs -q {4}'.format(
                raxml, prefix, phy, weights, asc, 'GAMMA', n_proc, rint, tree)
        run = Popen(cmd.split())
        run.communicate()

        tre = Tree('RAxML_result.{0}'.format(prefix), format=0)
        with open(phy + '.subtree', 'w') as fout:
            fout.write(tre.write(format=0) + '\n')
        # Post-order traversal: children already carry `d` (their leaf names)
        # when their parent is visited.
        for node in tre.get_descendants('postorder'):
            if node.is_leaf():
                node.d = [node.name]
            else:
                node.d = [n for c in node.children for n in c.d]
            branches.setdefault(tuple(sorted(node.d)), []).append(node.dist)

        # Best-effort cleanup. Some files (e.g. `phy.reduced`) may not exist,
        # and `asc` is None for non-ASC runs. Only filesystem errors are
        # ignored, not arbitrary exceptions.
        for fn in glob.glob('RAxML_*.{0}'.format(prefix)) + [
                phy, phy + '.reduced', weights, asc
        ]:
            if fn is None:
                continue
            try:
                os.unlink(fn)
            except OSError:
                pass

    tre = Tree(tree, format=1)
    leaves = set(tre.get_leaf_names())
    for node in tre.get_descendants('postorder'):
        if node.is_leaf():
            node.d = [node.name]
        else:
            node.d = [n for c in node.children for n in c.d]
        # A branch splits the leaves in two; RAxML may have recorded it
        # under either side.
        key1 = tuple(sorted(node.d))
        key2 = tuple(sorted(leaves - set(node.d)))
        if key1 in branches:
            node.dist = np.mean(branches[key1])
        elif key2 in branches:
            node.dist = np.mean(branches[key2])
        else:
            node.dist = 0.
        if -0.5 < node.dist * cnt < 0.5:
            node.dist = 0.0

    fname = '{0}.unrooted.nwk'.format(prefix)
    tre.write(outfile=fname, format=0)
    return fname
# Example #4
0
def extract_model_tree(model_file):
    """
    Return the set of leaf names of the tree stored in a model file.

    The last line of the model file is expected to be ``TREE: <newick>``.
    """
    # the last line of a model file is the tree
    with open(model_file) as f:
        lines = f.readlines()
    newick = lines[-1].split("TREE: ")[1]
    model_tree = Tree(newick, format=1)
    return set(model_tree.get_leaf_names())
def conserved_model_contains_sufficient_outgroups(model,
                                                  accelerated_genomes,
                                                  outgroup_genomes,
                                                  percent_outgroups=0.5):
    """makes sure that this region, when extracted, has at least percent_outgroups present

    Args:
        model: path to a model file whose last line is ``TREE: <newick>``.
        accelerated_genomes: genome names excluded from the outgroup count.
        outgroup_genomes: the full outgroup set; its size is the denominator.
        percent_outgroups: minimum ratio required.

    Returns:
        True if ``format_ratio(present_outgroups, len(outgroup_genomes))`` is
        at least ``percent_outgroups``.
    """
    with open(model) as f:
        lines = f.readlines()
    t = Tree(lines[-1].split('TREE: ')[1], format=1)
    # NOTE(review): every leaf that is not accelerated counts as an outgroup,
    # even if it is not in `outgroup_genomes` — confirm that is intended.
    outgroup_nodes = set(t.get_leaf_names()) - set(accelerated_genomes)
    return format_ratio(len(outgroup_nodes),
                        len(outgroup_genomes)) >= percent_outgroups
# Example #6
0
def main():
    """Append pruned-tree TMRCA statistics to every row of a gzipped table.

    Reads ``args.input_file_name`` (gzip, whitespace-separated). The first
    line is a header and gets three extra column names. For every other line,
    the Newick tree in column index 32 is pruned to the leaves whose names
    contain none of the comma-separated substrings in ``args.exclpops`` /
    ``args.exclindivs``. The row is written tab-separated to
    ``args.output_file_name`` (gzip), followed by the three values returned
    by ``tmrca_stats``.
    """
    excluded_populations = (args.exclpops.split(',')
                            if args.exclpops is not None else [])
    excluded_individuals = (args.exclindivs.split(',')
                            if args.exclindivs is not None else [])

    def keep_leaf(name):
        # Drop a leaf whose name contains any excluded population or
        # individual identifier as a substring.
        if any(pop in name for pop in excluded_populations):
            return False
        return not any(indiv in name for indiv in excluded_individuals)

    def emit(out, columns):
        out.write(('\t'.join(columns) + '\n').encode())

    with gzip.open(args.input_file_name, 'rb') as input_file, \
            gzip.open(args.output_file_name, 'wb') as output_file:
        first_line = next(input_file, None)
        if first_line is None:
            return
        emit(output_file,
             [tok.decode() for tok in first_line.split()]
             + ['pruned_tmrca', 'pruned_tmrca_half', 'pruned_coal_half'])

        for raw_line in input_file:
            columns = [tok.decode() for tok in raw_line.split()]
            tree = Tree(columns[32])

            kept_names = [leaf.name for leaf in tree.get_leaves()
                          if keep_leaf(leaf.name)]
            prune(tree, kept_names)

            # If the root is left with a single internal child, remove that
            # child (keeping branch lengths) so the root is not unary.
            if len(tree.children) == 1 and not tree.children[0].is_leaf():
                tree.children[0].delete(preserve_branch_length=True)

            assert set(tree.get_leaf_names()) == set(kept_names)

            tmrca, tmrca_half, coal_half = tmrca_stats(tree)
            emit(output_file,
                 columns + [str(tmrca), str(tmrca_half), str(coal_half)])
# Example #7
0
def refine_tree(tree_file,indir=None):
    """Return the leaf names of a tree file, in tree order.

    Bracketed integer annotations such as ``[12]`` and the token ``OROOT``
    are removed from the Newick text before parsing. If ``indir`` is given
    and exists, only names with a matching ``<name>.fna`` file in ``indir``
    are kept.
    """
    with open(tree_file) as fh:
        text = fh.read()
    # Raw string: "\[" in a plain literal is an invalid escape sequence
    # (warns on modern Python).
    new_text = re.sub(r"\[[0-9]+\]", '', text).replace('OROOT', '')
    right_ordered_names = Tree(new_text).get_leaf_names()
    if indir is not None and exists(indir):
        right_ordered_names = [rn
                               for rn in right_ordered_names
                               if exists(join(indir, f"{rn}.fna"))]
    return right_ordered_names
def TreeDataComparison(filename, tree_file="specieslistfortree2realtree.nwk"):
    """Return the species from a CSV export that also appear as leaves of a tree.

    Args:
        filename: comma-delimited CSV (a Navicat data export). Column index 2
            holds the Newick-formatted species name; other columns are ignored.
        tree_file: Newick tree read with ete3 format 1. Defaults to the file
            name this function used to hard-code.

    Returns:
        The unique species names from the CSV, in first-seen order, that are
        leaf names of the tree.
    """
    species = []
    seen = set()
    # newline='' is what the csv module expects for files given to reader().
    with open(filename, 'r', newline='') as f:
        for row in csv.reader(f, delimiter=','):
            newick_species = row[2]
            # Keep the first occurrence only, preserving order.
            if newick_species not in seen:
                seen.add(newick_species)
                species.append(newick_species)

    t = Tree(tree_file, format=1)
    leaf_names = set(t.get_leaf_names())
    return [name for name in species if name in leaf_names]
# Example #9
0
 def cmp_file_content(filepath: PathLike) -> bool:
     """Return True if the reference and target copies of `filepath` differ, False otherwise.

     Paths are resolved against ``self.dir_cmp.ref_path`` and
     ``self.dir_cmp.target_path`` (``self`` comes from the enclosing scope).
     Newick files (``.nw``, ``.nwk``, ``.newick``, ``.nh``) are compared as
     trees: first by the total leaf-to-root distance, then by the sorted leaf
     names. Any other file is compared with ``filecmp.cmp``.
     """
     import math

     ref_filepath = str(self.dir_cmp.ref_path / filepath)
     target_filepath = str(self.dir_cmp.target_path / filepath)
     if not ref_filepath.endswith(('.nw', '.nwk', '.newick', '.nh')):
         return not filecmp.cmp(ref_filepath, target_filepath)
     # Newick files: read and compare the trees rather than the raw text.
     ref_tree = Tree(ref_filepath, format=5)
     target_tree = Tree(target_filepath, format=5)
     # Sum of every leaf's distance to the root of its own tree.
     ref_sum = sum(leaf.get_distance(ref_tree) for leaf in ref_tree)
     target_sum = sum(leaf.get_distance(target_tree) for leaf in target_tree)
     # Float sums depend on summation order. The same tree written with its
     # children in another order can differ in the last bits, so compare
     # with a tolerance instead of `!=`.
     if not math.isclose(ref_sum, target_sum, rel_tol=1e-9, abs_tol=1e-12):
         return True
     # Check the leaves all match
     return sorted(ref_tree.get_leaf_names()) != sorted(target_tree.get_leaf_names())
# Example #10
0
def majority_tree(species_tree, node_num, phyparts_root):
    """Render the species tree for one PhyParts node's bipartition.

    Counts the lines of ``<phyparts_root>.concord.node.<node_num>`` (presumably
    one concordant gene tree per line — TODO confirm). It then looks up
    ``node_num`` in ``<phyparts_root>.node.key`` and, for each matching line,
    calls ``render_tree`` with that node's leaf names. The output file is
    ``node_<node_num>_speciestree.png``.
    """
    with open("{}.concord.node.{}".format(phyparts_root, node_num)) as concord:
        num_concord = sum(1 for _ in concord)
    # The original also passed num_concord to format() but the template has
    # only one placeholder, so the file name never included it.
    png_fn = "node_{}_speciestree.png".format(node_num)
    with open(phyparts_root + ".node.key") as key_file:
        for line in key_file:
            node = int(line.split()[0])
            if node == node_num:
                subtree = Tree(line.rstrip().split()[1] + ";")
                subtree_bipart = subtree.get_leaf_names()
                render_tree(species_tree, subtree_bipart, num_concord, png_fn)
def main(arg1,arg2):
    """Return a Robinson-Foulds-style split distance between two trees.

    Both trees are rooted on the first leaf name of the ``arg1`` tree, and
    only the non-outgroup child of each root is kept. Every internal node of
    the first tree (in ``dfs_assign`` order) is recorded as a
    ``"[min:max]"`` interval of its leaf names. Internal nodes of the second
    tree (``dfs_original`` order) whose leaves form a contiguous integer
    range are matched against those intervals. The result is
    ``splits_tree1 + splits_tree2 - 2 * shared``.

    NOTE(review): min/max are taken over the leaf names directly and then
    passed to ``int()``; this assumes ``dfs_assign`` gives the leaves
    numerically comparable names — TODO confirm.
    """
    first = Tree(arg1)
    second = Tree(arg2)

    anchor_leaf = first.get_leaf_names()[0]
    first.set_outgroup(anchor_leaf)
    second.set_outgroup(anchor_leaf)

    # Keep the side of each root that does not contain the anchor leaf.
    _, second = second.get_tree_root().children
    _, first = first.get_tree_root().children

    def interval_key(lo, hi):
        return "[" + str(lo) + ":" + str(hi) + "]"

    ordered_first = dfs_assign([first], second)

    splits_first = 0
    intervals = dict()
    for node in ordered_first:
        if node.is_leaf():
            continue
        splits_first += 1
        names = node.get_leaf_names()
        key = interval_key(min(names), max(names))
        if not node.is_root():
            intervals[key] = 1

    splits_second = 0
    common = 0
    for node in dfs_original([second]):
        if node.is_leaf():
            continue
        splits_second += 1
        names = node.get_leaf_names()
        lo = min(names)
        hi = max(names)
        # Only a contiguous block of leaf labels can match a stored interval.
        if len(names) == (int(hi) - int(lo) + 1) and interval_key(lo, hi) in intervals:
            common += 1

    return splits_first + splits_second - (2 * common)
# Example #12
0
# File: plot.py  Project: gamcil/fungphy
def add_section_annotations(tree: Tree) -> None:
    """Annotates taxonomic sections.

    Groups the tree's leaves (strain IDs) by the section of their species,
    queried from the database session, and skips strains whose epithet
    contains "FP". For each section, the section's node gets a background
    colour: the leaf itself for a single-strain section, otherwise the MRCA
    of its strains. A text label with the section name is added to the first
    strain's leaf. Colours cycle through a fixed palette.

    Relies on accurate section annotation - FP strains were set to Talaromyces
    which breaks this.
    """
    leaf_ids = tree.get_leaf_names()
    members = defaultdict(list)
    for strain in session.query(Strain).filter(Strain.id.in_(leaf_ids)):
        if "FP" not in strain.species.epithet:
            members[strain.species.section.name].append(str(strain.id))

    palette = [
        "LightSteelBlue",
        "Moccasin",
        "DarkSeaGreen",
        "Khaki",
        "LightSalmon",
        "Turquoise",
        "Thistle"
    ]

    for position, (section, ids) in enumerate(members.items()):
        # Wrap around the palette when there are more sections than colours.
        background = NodeStyle()
        background["bgcolor"] = palette[position % len(palette)]
        if len(ids) == 1:
            target = tree.search_nodes(name=ids[0])[0]
        else:
            target = tree.get_common_ancestor(*ids)
        target.set_style(background)

        label_node = tree.search_nodes(name=ids[0])[0]
        label = faces.TextFace(section, fsize=20)
        label_node.add_face(label, column=1, position="aligned")
# Example #13
0
def extract_ss(input_path, suffix, tree_file):
    """Create a new dataset version whose alignment keeps only the tree's taxa.

    The new dataset is ``input_path._dataset + suffix`` (same version). Its
    FASTA alignment contains only the sequences whose labels are leaves of
    ``tree_file``; the duplicates JSON and outgroups file are copied over.
    """
    tree = Tree(tree_file, format=1)
    leaves_set = set(tree.get_leaf_names())
    msa = SeqGroup(input_path.alignment, "fasta")
    path_argv = [input_path._version, input_path._dataset + suffix]
    output_path = common.Paths(path_argv, 0)
    data_versioning.setup_new_dataset(output_path)
    new_msa = SeqGroup()
    for entry in msa.iter_entries():
        label = entry[0]
        sequence = entry[1]
        if label in leaves_set:
            new_msa.set_seq(label, sequence)
    # Close (and flush) the output deterministically instead of relying on GC.
    with open(output_path.alignment, "w") as out:
        out.write(new_msa.write(format="fasta"))
    shutil.copy(input_path.duplicates_json, output_path.duplicates_json)
    shutil.copy(input_path.outgroups_file, output_path.outgroups_file)
# Example #14
0
def _create_tree(tree, fasta, out, color):
    """Render ``tree`` to ``out`` with each leaf's sequence drawn beside it.

    Sequences come from the FASTA file ``fasta``. Leaves whose names appear
    in the colour mapping returned by ``_parse_color_file(color)`` get that
    background colour.
    """
    seqs = SeqGroup(fasta, format="fasta")
    t = Tree(tree)
    colors = _parse_color_file(color)
    for leaf_name in t.get_leaf_names():
        motif_face = SeqMotifFace(seqs.get_seq(leaf_name), seq_format="()")
        for leaf in t.get_leaves_by_name(leaf_name):
            if leaf_name in colors:
                style = NodeStyle()
                style['bgcolor'] = colors[leaf_name]
                leaf.set_style(style)
            leaf.add_face(motif_face, 0, 'aligned')
    t.render(out)
def prune_species_tree(gene_tree,
                       cached_species_tree=None,
                       keep_polytomies=False):
    """Return the species tree restricted to the species present in ``gene_tree``.

    Gene-tree leaf names must look like ``<speciesID>_<ProteinName>``. The
    digits of each species ID are matched against leaf names of the big
    species tree: ``cached_species_tree`` if truthy, otherwise one parsed
    from ``EGGNOGv4_SPECIES_TREE``. A copy of their MRCA subtree is pruned to
    those species, optionally binarised, and relabelled with the full species
    IDs. The result is returned as a Newick string (ete3 format 5).
    """
    gene = Tree(gene_tree)

    species_names = list({name.split('_')[0] for name in gene.get_leaf_names()})
    # digits-only id -> species label as written in the gene tree
    digit_to_species = {''.join(filter(str.isdigit, sp)): sp for sp in species_names}

    reference = cached_species_tree if cached_species_tree else Tree(EGGNOGv4_SPECIES_TREE)

    # Work on a copy so a cached reference tree is never modified.
    subtree = reference.get_common_ancestor(list(digit_to_species.keys())).copy()

    # Delete unwanted leaves directly (TreeNode.prune is slower).
    by_name = {leaf.name: leaf for leaf in subtree.get_leaves()}
    for unwanted in by_name.keys() - digit_to_species.keys():
        by_name[unwanted].delete()
    assert (len(subtree.get_leaf_names()) == len(digit_to_species))

    if not keep_polytomies:
        subtree.resolve_polytomy(recursive=True)

    for leaf in subtree.get_leaves():
        leaf.name = digit_to_species[leaf.name]

    return subtree.write(format=5)
# Example #16
0
def extract_ss(input_path, suffix, tree_file):
    """Create a new dataset version whose alignment keeps only the tree's taxa.

    Prints progress messages. The new dataset is
    ``input_path._dataset + suffix`` (same version). Its FASTA alignment
    contains only the sequences whose labels are leaves of ``tree_file``;
    the duplicates JSON and outgroups file are copied over.
    """
    print(
        "Extracting alignment generated with the support selection tree thinning technique..."
    )
    tree = Tree(tree_file, format=1)
    leaves_set = set(tree.get_leaf_names())
    msa = SeqGroup(input_path.alignment, "fasta")
    path_argv = [input_path._version, input_path._dataset + suffix]
    output_path = common.Paths(path_argv, 0)
    data_versioning.setup_new_dataset(output_path)
    new_msa = SeqGroup()
    for entry in msa.iter_entries():
        label = entry[0]
        sequence = entry[1]
        if label in leaves_set:
            new_msa.set_seq(label, sequence)
    # Close (and flush) the output deterministically instead of relying on GC.
    with open(output_path.alignment, "w") as out:
        out.write(new_msa.write(format="fasta"))
    shutil.copy(input_path.duplicates_json, output_path.duplicates_json)
    shutil.copy(input_path.outgroups_file, output_path.outgroups_file)
    print("New version of the snapshot: " + output_path.path)
# Example #17
0
def main(mcmc_out_tree, out_table):
    """Tabulate posterior node ages from an MCMC output tree.

    Uses the last ``UTREE 1 =`` line of ``mcmc_out_tree``. Its
    ``[&95%HPD=...]`` annotations are rewritten with ``sub_for``, and the
    result is parsed as a format-1 Newick tree, so internal node names carry
    the annotations. Internal nodes are renamed either ``I<k>`` or, when a
    ``*.out`` file exists next to the tree file, ``t_n<name>`` using the tree
    returned by ``get_node_name``. A TSV with node name, distance to the
    first leaf and the original annotation is written to ``out_table``.

    Raises:
        ValueError: if the file contains no ``UTREE 1 =`` line.
    """
    tree = None
    with open(mcmc_out_tree) as handle:
        for row in handle:
            if row.strip().startswith('UTREE 1 ='):
                t = row.split('UTREE 1 =')[1].strip('\n')
                t = re.sub(r'\[&95%HPD=.*?\]', sub_for, t)
                t = t.replace(' ', '')
                tree = Tree(t, format=1)
    if tree is None:
        raise ValueError(
            "no 'UTREE 1 =' line found in {}".format(mcmc_out_tree))

    out_files = glob(join(dirname(mcmc_out_tree), '*.out'))
    rename_tree = get_node_name(out_files[0]) if out_files else None

    rows = ["\t".join(["node name", "Posterior mean time", "CIs"])]

    # Distances are measured to the first leaf. The topology is not modified
    # below, so look that leaf up once instead of on every iteration.
    reference_leaf = tree.get_leaves()[0]
    count = len(tree.get_leaf_names()) + 1
    for n in tree.traverse():
        if n.is_leaf():
            continue
        dates = n.name
        if rename_tree is None:
            n.name = 'I%s' % count
        else:
            n.name = 't_n%s' % rename_tree.get_common_ancestor(
                n.get_leaf_names()).name
        n.add_features(ages=dates)
        count += 1

        rows.append("\t".join([
            n.name,
            str(n.get_distance(reference_leaf)),
            str(dates)
        ]))

    out_dir = dirname(process_path(out_table))
    # dirname() is '' for a bare file name, and os.makedirs('') would raise.
    if out_dir and not exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)
    with open(out_table, 'w') as f1:
        f1.write('\n'.join(rows))
# Example #18
0
    def _parse_tree_tag(self, dir_path):
        """Map each tagged internal node of the reference tree to its name.

        Reads the first line of ``<dir_path>/<self.name_tool>/tree_tag.txt``
        as a Newick tree (ete3 format 8) and stores its leaf names in
        ``self.all_species``. Then fills ``self.true_ancestor_names`` with
        ``key -> node.name`` for every non-root, non-leaf node whose name is
        not a species name. ``key`` is built from three sorted name lists:
        the leaves under the first child, those under the second child, and
        all remaining species. Each list is converted with ``str()``, and the
        three strings are sorted and joined with single spaces.
        """
        self.true_ancestor_names = {}
        tag_path = os.path.join(dir_path, self.name_tool, "tree_tag.txt")
        with open(tag_path) as handle:
            input_tree = Tree(handle.readline().strip("\n\t"), 8)
            self.all_species = set(input_tree.get_leaf_names())
            # Preorder starts at the root; skip it.
            for node in list(input_tree.traverse("preorder"))[1:]:
                if node.name in self.all_species or node.is_leaf():
                    continue
                left = sorted(node.children[0].get_leaf_names())
                right = sorted(node.children[1].get_leaf_names())
                rest = sorted(self.all_species.difference(set(left).union(set(right))))
                parts = sorted([str(left), str(right), str(rest)])
                self.true_ancestor_names[" ".join(parts)] = node.name
# Example #19
0
    def get_newick_tree(self):
        """
        Read and parse the ``newick_tree`` field of the ``SPECIES`` config section.

        A trailing ``;`` is appended when missing. Logs an error and exits with
        status 1 if the field is empty, cannot be parsed, or contains a leaf
        (informal species) name with an underscore or a space.

        :return tree: the parsed ete3 ``Tree``
        """
        newick = self.config.get("SPECIES", "newick_tree")
        if not newick.endswith(';'):
            newick += ";"
        if newick in ("();", ";"):
            logging.error(
                'Field "newick_tree" in configuration file is empty, please fill in'
            )
            sys.exit(1)
        try:
            tree = Tree(newick)
        except Exception:
            logging.error(
                'Unrecognized format for field "newick_tree" in configuration file (for example, parentheses do not match)'
            )
            sys.exit(1)

        # Informal species names may not contain underscores or spaces.
        bad_names = [name for name in tree.get_leaf_names()
                     if "_" in name or " " in name]
        if bad_names:
            logging.error(
                "Informal species' names must not contain any spaces or underscores. Please change the following names in the configuration file:"
            )
            for name in bad_names:
                logging.error(f"- {name}")
            sys.exit(1)
        return tree
# Example #20
0
class Species:
    def __init__(self,
                 path,
                 max_unknowns=200,
                 contigs=3.0,
                 assembly_size=3.0,
                 mash=3.0,
                 assembly_summary=None,
                 processes=1):
        """Represents a collection of genomes in `path`

        Sets up all QC paths under ``<path>/qc``, loads any cached results
        from previous runs (stats, tree, failed report, distance matrix,
        metadata), builds one ``Genome.Genome`` per ``*.fasta`` file in
        ``path`` and finally calls ``assess_tree()``.

        :param path: Path to the directory of related genomes you wish to analyze.
        :param max_unknowns: Number of allowable unknown bases, i.e. not [ATCG]
        :param contigs: Acceptable deviations from median number of contigs
        :param assembly_size: Acceptable deviations from median assembly size
        :param mash: Acceptable deviations from median MASH distances
        :param assembly_summary: a pandas DataFrame with assembly summary information
        :param processes: Number of worker processes, stored as ``self.ncpus``
        """
        self.max_unknowns = max_unknowns
        self.contigs = contigs
        self.assembly_size = assembly_size
        self.mash = mash
        self.assembly_summary = assembly_summary
        self.deviation_values = [max_unknowns, contigs, assembly_size, mash]
        self.ncpus = processes
        self.path = os.path.abspath(path)
        self.name = os.path.basename(os.path.normpath(path))
        self.log = logbook.Logger(self.name)
        self.qc_dir = os.path.join(self.path, "qc")
        # Per-threshold results live in qc/<label>/, where label joins the
        # four deviation values with '-'.
        self.label = '-'.join(map(str, self.deviation_values))
        self.qc_results_dir = os.path.join(self.qc_dir, self.label)
        self.passed_dir = os.path.join(self.qc_results_dir, "passed")
        self.stats_path = os.path.join(self.qc_dir, 'stats.csv')
        self.nw_path = os.path.join(self.qc_dir, 'tree.nw')
        self.dmx_path = os.path.join(self.qc_dir, 'dmx.csv')
        self.failed_path = os.path.join(self.qc_results_dir, "failed.csv")
        self.tree_img = os.path.join(self.qc_results_dir, "tree.svg")
        self.summary_path = os.path.join(self.qc_results_dir, "summary.txt")
        self.allowed_path = os.path.join(self.qc_results_dir, "allowed.p")
        self.paste_file = os.path.join(self.qc_dir, 'all.msh')
        # Figure out if defining these as None is necessary
        self.tree = None
        self.stats = None
        self.dmx = None
        # Reuse artefacts written by a previous run when they exist on disk.
        if os.path.isfile(self.stats_path):
            self.stats = pd.read_csv(self.stats_path, index_col=0)
        if os.path.isfile(self.nw_path):
            self.tree = Tree(self.nw_path, 1)
        # NOTE(review): failed_report is only assigned when failed.csv exists;
        # later reads of it would raise AttributeError otherwise.
        if os.path.isfile(self.failed_path):
            self.failed_report = pd.read_csv(self.failed_path, index_col=0)
        if os.path.isfile(self.dmx_path):
            try:
                self.dmx = pd.read_csv(self.dmx_path, index_col=0, sep="\t")
                self.log.info("Distance matrix read succesfully")
            except pd.errors.EmptyDataError:
                self.log.exception()
        self.metadata_path = os.path.join(self.qc_dir,
                                          "{}_metadata.csv".format(self.name))
        try:
            self.metadata_df = pd.read_csv(self.metadata_path,
                                           index_col="accession")
        except FileNotFoundError:
            # No metadata yet: start with an empty frame with the key column.
            self.metadata_df = pd.DataFrame(columns=["accession"])
        self.criteria = ["unknowns", "contigs", "assembly_size", "distance"]
        self.tolerance = {
            "unknowns": max_unknowns,
            "contigs": contigs,
            "assembly_size": assembly_size,
            "distance": mash
        }
        # Filtering starts from all genomes with stats; filters narrow this down.
        self.passed = self.stats
        self.failed = {}
        self.med_abs_devs = {}
        self.dev_refs = {}
        self.allowed = {"unknowns": max_unknowns}
        self.colors = {
            "unknowns": "red",
            "contigs": "green",
            "distance": "purple",
            "assembly_size": "orange"
        }
        # genome_paths lists the *.fasta files directly inside self.path.
        self.genomes = [
            Genome.Genome(genome, self.assembly_summary)
            for genome in self.genome_paths
        ]
        # Sets self.tree_complete; needs self.tree, self.stats and self.genomes.
        self.assess_tree()

    def __str__(self):
        self.message = [
            "Species: {}".format(self.name),
            "Maximum Unknown Bases:  {}".format(self.max_unknowns),
            "Acceptable Deviations,", "Contigs, {}".format(self.contigs),
            "Assembly Size, {}".format(self.assembly_size),
            "MASH: {}".format(self.mash)
        ]
        return '\n'.join(self.message)

    def assess(f):
        # TODO: This can have a more general application if the pickling
        # functionality is implemented elsewhere
        @functools.wraps(f)
        def wrapper(self):
            try:
                assert self.stats is not None
                assert os.path.isfile(self.allowed_path)
                assert (sorted(self.genome_ids().tolist()) == sorted(
                    self.stats.index.tolist()))
                self.complete = True
                with open(self.allowed_path, 'rb') as p:
                    self.allowed = pickle.load(p)
                self.log.info('Already complete')
            except AssertionError:
                self.complete = False
                f(self)

        return wrapper

    def assess_tree(self):
        try:
            assert self.tree is not None
            assert self.stats is not None
            leaf_names = [
                re.sub(".fasta", "", i) for i in self.tree.get_leaf_names()
            ]
            assert (sorted(leaf_names) == sorted(self.stats.index.tolist()) ==
                    sorted(self.genome_ids().tolist()))
            self.tree_complete = True
            self.log.info("Tree already complete")
        except AssertionError:
            self.tree_complete = False

    @property
    def genome_paths(self, ext="fasta"):
        # Why doesn't this work when importing at top of file?
        """Returns a generator for every file ending with `ext`

        :param ext: File extension of genomes in species directory
        :returns: Generator of Genome objects for all genomes in species dir
        :rtype: generator
        """
        return [
            os.path.join(self.path, genome) for genome in os.listdir(self.path)
            if genome.endswith(ext)
        ]

    # @property
    # def genomes(self):
    #     """Returns a generator for every file ending with `ext`

    #     :param ext: File extension of genomes in species directory
    #     :returns: Generator of Genome objects for all genomes in species dir
    #     :rtype: generator
    #     """
    #     return (Genome.Genome(genome, self.assembly_summary) for genome in self.genome_paths)

    @property
    def total_genomes(self):
        return len(list(self.genomes))

    def sketches(self):
        return (i.msh for i in self.genomes)

    def genome_ids(self):
        ids = [i.name for i in self.genomes]
        return pd.Index(ids)

    # may be redundant. see genome_ids attrib
    @property
    def accession_ids(self):
        ids = [
            i.accession_id for i in self.genomes if i.accession_id is not None
        ]
        return ids

    def mash_paste(self):
        """Combine all ``*msh`` sketches in ``self.qc_dir`` into ``self.paste_file``.

        An existing paste file is removed first. If mash does not produce the
        file, an error is logged and ``self.paste_file`` is set to None.
        """
        import shlex

        if os.path.isfile(self.paste_file):
            os.remove(self.paste_file)
        # Quote the paths for the shell, but leave the glob unquoted so the
        # shell still expands it (the paths were previously unquoted).
        sketches = shlex.quote(self.qc_dir) + "/*msh"
        cmd = "mash paste {} {}".format(shlex.quote(self.paste_file), sketches)
        # shell=True (a real bool, not the string "True") for glob expansion.
        Popen(cmd, shell=True, stderr=DEVNULL).wait()
        self.log.info("MASH paste completed")
        if not os.path.isfile(self.paste_file):
            self.log.error("MASH paste failed")
            self.paste_file = None

    def mash_dist(self):
        """Compute the all-vs-all MASH distance table and store it in ``self.dmx``.

        Runs ``mash dist`` on ``self.paste_file`` against itself and writes
        the table to ``self.dmx_path``. Index and column labels are then
        reduced to file names without directory or extension, and the table is
        saved back (tab-separated).
        """
        import shlex

        # shlex.quote handles any path, including ones containing a single
        # quote, which broke the previous '...' quoting.
        cmd = "mash dist -p {} -t {} {} > {}".format(
            self.ncpus, shlex.quote(self.paste_file),
            shlex.quote(self.paste_file), shlex.quote(self.dmx_path))
        Popen(cmd, shell=True, stderr=DEVNULL).wait()
        self.log.info("MASH distance completed")
        self.dmx = pd.read_csv(self.dmx_path, index_col=0, sep="\t")
        # Make distance matrix more readable
        names = [os.path.splitext(i)[0].split('/')[-1] for i in self.dmx.index]
        self.dmx.index = names
        self.dmx.columns = names
        self.dmx.to_csv(self.dmx_path, sep="\t")
        self.log.info("dmx.csv created")

    def sketch_genomes(self):
        """Sketch every genome file with ``Genome.sketch_genome`` in a process pool."""
        with Pool(ncpus=self.ncpus) as workers:
            self.log.info(f"{workers.ncpus} cpus in pool")
            workers.map(Genome.sketch_genome, self.genome_paths)
        self.log.info("All genomes sketched")

    def get_tree(self):
        """Build a clustering tree from ``self.dmx`` and write it to ``self.nw_path``.

        Only runs when ``self.tree_complete`` is False. Leaves are named
        ``<genome>.fasta``. The tree is midpoint-rooted when ete3 allows it
        (failures are logged) and stored on ``self.tree``.
        """
        # Use decorator instead of if statement
        if self.tree_complete is False:
            from ete3.coretype.tree import TreeError
            import numpy as np
            from skbio.tree import TreeNode
            from scipy.cluster.hierarchy import weighted
            ids = ['{}.fasta'.format(i) for i in self.dmx.index.tolist()]
            # DataFrame.as_matrix() was removed in pandas 1.0; to_numpy()
            # returns the same ndarray.
            # NOTE(review): scipy's weighted() treats a 2-D array as
            # observation vectors, not as a distance matrix, so this clusters
            # rows of the upper triangle. A condensed matrix
            # (scipy.spatial.distance.squareform) may be what was intended —
            # confirm before changing results.
            triu = np.triu(self.dmx.to_numpy())
            hclust = weighted(triu)
            t = TreeNode.from_linkage_matrix(hclust, ids)
            nw = str(t).replace("'", "")
            self.tree = Tree(nw)
            # midpoint root tree
            try:
                self.tree.set_outgroup(self.tree.get_midpoint_outgroup())
            except TreeError:
                self.log.exception("Midpoint rooting failed")
            self.tree.write(outfile=self.nw_path)

    def get_stats(self):
        """Get stats for all genomes in parallel, concatenate them into a DataFrame and save it.

        Each genome is passed to ``Genome.mp_stats`` together with the column
        means of ``self.dmx``. The result is stored on ``self.stats`` and
        written to ``self.stats_path``.
        """
        per_genome_means = [self.dmx.mean()] * len(self.genome_paths)
        with Pool(ncpus=self.ncpus) as workers:
            frames = workers.map(Genome.mp_stats, self.genome_paths, per_genome_means)
        self.stats = pd.concat(frames)
        self.stats.to_csv(self.stats_path)
        self.log.info("Generated stats and wrote to disk")

    def MAD(self, df, col):
        """Return the mean absolute deviation of ``df[col]`` from its median.

        Despite the name, the absolute deviations are averaged with
        ``mean()``, not reduced with a median.
        """
        column = df[col]
        deviations = (column - column.median()).abs()
        return deviations.mean()

    def MAD_ref(MAD, tolerance):
        """Get the reference value for median absolute deviation"""
        dev_ref = MAD * tolerance
        return dev_ref

    def bound(df, col, dev_ref):
        lower = df[col].median() - dev_ref
        upper = df[col].median() + dev_ref
        return lower, upper

    def filter_unknown_bases(self):
        """Fail genomes whose ``unknowns`` stat exceeds the configured tolerance.

        The failing index labels go to ``self.failed["unknowns"]``.
        ``self.passed`` becomes ``self.stats`` without those rows.
        """
        limit = self.tolerance["unknowns"]
        over_limit = self.stats["unknowns"] > limit
        failing = self.stats.index[over_limit]
        self.failed["unknowns"] = failing
        self.passed = self.stats.drop(failing)
        self.log.info("Analyzed unknowns")

    def check_passed_count(f):
        """Decorator: run a filter step only while more than five genomes pass.

        The wrapped method is called as ``f(self, criteria, ...)``. It runs
        only when ``self.passed`` has more than five rows.

        Otherwise the step is skipped:
            * ``self.allowed[criteria]`` and ``self.failed[criteria]`` are
              set to ``''`` so later reporting still finds an entry.
            * A log line records where filtering stopped.

        ``criteria`` may be given positionally or as a keyword. The wrapped
        method's return value is passed through.
        """
        @functools.wraps(f)
        def wrapper(self, *args, **kwargs):
            if len(self.passed) > 5:
                return f(self, *args, **kwargs)
            # Previously args[0] raised IndexError for keyword-only calls.
            criteria = args[0] if args else kwargs["criteria"]
            self.allowed[criteria] = ''
            self.failed[criteria] = ''
            self.log.info("Stopped filtering after {}".format(f.__name__))
            return None

        return wrapper

    @check_passed_count
    def filter_contigs(self, criteria):
        """
        Only look at genomes with > 10 contigs to avoid throwing off the
        median absolute deviation.
        Median absolute deviation - Average absolute difference between
        number of contigs and the median for all genomes
        Extract genomes with < 10 contigs to add them back in later.
        Add genomes with < 10 contigs back in
        """
        eligible_contigs = self.passed.contigs[self.passed.contigs > 10]
        not_enough_contigs = self.passed.contigs[self.passed.contigs <= 10]
        # TODO Define separate function for this
        med_abs_dev = abs(eligible_contigs - eligible_contigs.median()).mean()
        self.med_abs_devs["contigs"] = med_abs_dev
        # Define separate function for this
        # The "deviation reference"
        dev_ref = med_abs_dev * self.contigs
        self.dev_refs["contigs"] = dev_ref
        self.allowed["contigs"] = eligible_contigs.median() + dev_ref
        self.failed["contigs"] = eligible_contigs[
            abs(eligible_contigs - eligible_contigs.median()) > dev_ref].index
        eligible_contigs = eligible_contigs[
            abs(eligible_contigs - eligible_contigs.median()) <= dev_ref]
        eligible_contigs = pd.concat([eligible_contigs, not_enough_contigs])
        eligible_contigs = eligible_contigs.index
        self.passed = self.passed.loc[eligible_contigs]
        self.log.info("Analyzed contigs")

    @check_passed_count
    def filter_MAD_range(self, criteria):
        """
        Filter based on median absolute deviation.
        Passing values fall within a lower and upper bound.
        """
        # Get the median absolute deviation
        med_abs_dev = abs(self.passed[criteria] -
                          self.passed[criteria].median()).mean()
        dev_ref = med_abs_dev * self.tolerance[criteria]
        lower = self.passed[criteria].median() - dev_ref
        upper = self.passed[criteria].median() + dev_ref
        allowed_range = (str(int(x)) for x in [lower, upper])
        allowed_range = '-'.join(allowed_range)
        self.allowed[criteria] = allowed_range
        self.failed[criteria] = self.passed[
            abs(self.passed[criteria] -
                self.passed[criteria].median()) > dev_ref].index
        self.passed = self.passed[abs(
            self.passed[criteria] - self.passed[criteria].median()) <= dev_ref]
        self.log.info("Filtered based on median absolute deviation range")

    @check_passed_count
    def filter_MAD_upper(self, criteria):
        """
        Filter based on median absolute deviation.
        Passing values fall under the upper bound.
        """
        # Get the median absolute deviation
        med_abs_dev = abs(self.passed[criteria] -
                          self.passed[criteria].median()).mean()
        dev_ref = med_abs_dev * self.tolerance[criteria]
        upper = self.passed[criteria].median() + dev_ref
        self.failed[criteria] = self.passed[
            self.passed[criteria] > upper].index
        self.passed = self.passed[self.passed[criteria] <= upper]
        upper = "{:.4f}".format(upper)
        self.allowed[criteria] = upper
        self.log.info("Filtered based on MAD upper bound")

    def base_node_style(self):
        """Give every tree node a small black sphere and label FASTA leaves.

        All nodes share one ``NodeStyle``. Nodes whose name matches the
        regex ``.*fasta`` also get an ``AttrFace`` showing their name
        (font size 8, right margin 150, left margin 3) in column 0.
        """
        from ete3 import NodeStyle, AttrFace
        shared_style = NodeStyle()
        shared_style["shape"] = "sphere"
        shared_style["size"] = 2
        shared_style["fgcolor"] = "black"
        for node in self.tree.traverse():
            node.set_style(shared_style)
            if not re.match('.*fasta', node.name):
                continue
            label = AttrFace('name', fsize=8)
            label.margin_right = 150
            label.margin_left = 3
            node.add_face(label, column=0)
        self.log.info("Applied base node style")

    # Might be better in a layout function
    def style_and_render_tree(self, file_types=("svg",)):
        """Build a titled legend and render ``self.tree`` once per file type.

        Legend layout:
            * Column 1: an empty cell, then the headings Allowed, Tolerance,
              Filtered and Color.
            * Columns 2 onward: one per entry of ``self.criteria``, holding
              its title, allowed value, tolerance, number of genomes in
              ``self.failed_report`` with that criteria, and a color swatch
              from ``self.colors``.

        Args:
            file_types: Iterable of file extensions. Each one produces
                ``<self.qc_results_dir>/tree.<ext>``. The default is an
                immutable tuple instead of a shared list; lists still work.
        """
        from ete3 import TreeStyle, TextFace, CircleFace
        ts = TreeStyle()
        title_face = TextFace(self.name.replace('_', ' '), fsize=20)
        title_face.margin_bottom = 10
        ts.title.add_face(title_face, column=0)
        ts.branch_vertical_margin = 10
        ts.show_leaf_name = False
        # Legend column 1: an empty cell followed by the row headings.
        ts.legend.add_face(TextFace(""), column=1)
        for heading in ["Allowed", "Tolerance", "Filtered", "Color"]:
            heading_face = TextFace(heading, fsize=8, bold=True)
            heading_face.margin_bottom = 2
            heading_face.margin_right = 40
            ts.legend.add_face(heading_face, column=1)
        # One legend column per criteria, starting at column 2.
        for i, criteria in enumerate(self.criteria, 2):
            title = TextFace(criteria.replace("_", " ").title(), fsize=8, bold=True)
            title.margin_bottom = 2
            title.margin_right = 40
            cf = CircleFace(4, self.colors[criteria], style="sphere")
            cf.margin_bottom = 5
            # Number of genomes rejected by this criteria.
            filtered_count = int((self.failed_report.criteria == criteria).sum())
            filtered = TextFace(filtered_count, fsize=8)
            filtered.margin_bottom = 5
            allowed = TextFace(self.allowed[criteria], fsize=8)
            allowed.margin_bottom = 5
            allowed.margin_right = 25
            tolerance = TextFace(self.tolerance[criteria], fsize=8)
            tolerance.margin_bottom = 5
            # Row order matches the headings added to column 1.
            for face in (title, allowed, tolerance, filtered, cf):
                ts.legend.add_face(face, column=i)
        for f in file_types:
            out_tree = os.path.join(self.qc_results_dir, 'tree.{}'.format(f))
            self.tree.render(out_tree, tree_style=ts)
            self.log.info("tree.{} generated".format(f))

    def color_tree(self):
        """Highlight failed genomes on the tree, then render it.

        The base style is applied first. Then each genome in
        ``self.failed_report`` gets a size-9 node colored by the criteria it
        failed (``self.colors``). Finally ``style_and_render_tree`` is called.
        """
        from ete3 import NodeStyle
        self.base_node_style()
        report = self.failed_report
        for genome in report.index:
            matches = self.tree.get_leaves_by_name(genome + ".fasta")
            leaf = matches.pop()
            highlight = NodeStyle()
            highlight["fgcolor"] = self.colors[report.loc[genome, 'criteria']]
            highlight["size"] = 9
            leaf.set_style(highlight)
        self.style_and_render_tree()

    def filter(self):
        """Run every QC filter in order, persist the thresholds, then report.

        Order:
            1. Unknown bases.
            2. Contigs.
            3. Assembly size (MAD range).
            4. MASH distance (MAD upper bound).

        ``self.allowed`` is pickled to ``self.allowed_path`` before the
        summary and the failed report are written.
        """
        self.filter_unknown_bases()
        for run_step, criteria in (
            (self.filter_contigs, "contigs"),
            (self.filter_MAD_range, "assembly_size"),
            (self.filter_MAD_upper, "distance"),
        ):
            run_step(criteria)
        with open(self.allowed_path, 'wb') as handle:
            pickle.dump(self.allowed, handle)
            self.log.info("Pickled results of filtering")
        self.summary()
        self.write_failed_report()

    def write_failed_report(self):
        """Build ``self.failed_report`` (genome -> failing criteria) and save it.

        Every label stored in ``self.failed`` becomes a row, and its
        ``criteria`` column names the filter that rejected it. Placeholders
        that are not an index (the ``''`` recorded for a skipped filter)
        contribute no rows. Any existing file at ``self.failed_path`` is
        removed before the CSV is written.
        """
        from itertools import chain
        if os.path.isfile(self.failed_path):
            os.remove(self.failed_path)
        # Materialize the labels so the DataFrame gets a concrete list.
        ixs = list(chain.from_iterable(self.failed.values()))
        self.failed_report = pd.DataFrame(index=ixs, columns=["criteria"])
        for criteria, failed in self.failed.items():
            # isinstance also accepts Index subclasses (RangeIndex and the
            # numeric index types of older pandas); the former exact
            # type() comparison silently left those rows unlabeled.
            if isinstance(failed, pd.Index):
                self.failed_report.loc[failed, 'criteria'] = criteria
        self.failed_report.to_csv(self.failed_path)
        self.log.info("Wrote failed report")

    def summary(self):
        """Write a plain-text QC summary to ``self.summary_path`` and return it.

        The text starts with ``self.name``. Then, for unknown bases, contigs,
        assembly size and MASH distance in that order, it lists:
            * a heading,
            * the allowed value,
            * the tolerance,
            * the number of filtered genomes,
            * a separator entry.
        All entries are joined with newlines.
        """
        sections = (
            ("Unknown Bases", "unknowns"),
            ("Contigs", "contigs"),
            ("Assembly Size", "assembly_size"),
            ("MASH", "distance"),
        )
        entries = [self.name]
        for heading, key in sections:
            entries += [
                heading,
                "Allowed: {}".format(self.allowed[key]),
                "Tolerance: {}".format(self.tolerance[key]),
                "Filtered: {}".format(len(self.failed[key])),
                "\n",
            ]
        text = '\n'.join(entries)
        with open(os.path.join(self.summary_path), "w") as handle:
            handle.write(text)
            self.log.info("Wrote QC summary")
        return text

    def link_genomes(self):
        """Hard-link each passing genome's FASTA file into ``self.passed_dir``.

        The directory is created if it does not exist. A link that already
        exists is left untouched.
        """
        destination = self.passed_dir
        if not os.path.exists(destination):
            os.mkdir(destination)
        for fname in ("{}.fasta".format(g) for g in self.passed.index):
            source_file = os.path.join(self.path, fname)
            target_file = os.path.join(destination, fname)
            try:
                os.link(source_file, target_file)
            except FileExistsError:
                continue
        self.log.info("Links created for genomes that passed QC")

    @assess
    def qc(self):
        """Run the full QC pipeline.

        Ensures ``self.qc_dir`` and ``self.qc_results_dir`` exist, then runs,
        in order:
            1. sketching,
            2. MASH paste,
            3. MASH dist,
            4. stats,
            5. filtering,
            6. linking passed genomes,
            7. tree building,
            8. tree coloring/rendering.
        """
        for directory in (self.qc_dir, self.qc_results_dir):
            if not os.path.isdir(directory):
                os.mkdir(directory)
        pipeline = (
            self.sketch_genomes,
            self.mash_paste,
            self.mash_dist,
            self.get_stats,
            self.filter,
            self.link_genomes,
            self.get_tree,
            self.color_tree,
        )
        for stage in pipeline:
            stage()
        self.log.info("qc command completed")

    def metadata(self):
        """Fetch metadata for genomes not yet cached and save the table.

        Genomes whose ``accession_id`` is already in ``self.metadata_df``
        are skipped. For the others, ``get_metadata()`` is called and the
        resulting records are appended, indexed by their ``"accession"``
        field. The combined table is written to ``self.metadata_path``.
        """
        new_records = []
        for genome in self.genomes:
            if genome.accession_id in self.metadata_df.index:
                continue
            genome.get_metadata()
            new_records.append(genome.metadata)
        # pd.DataFrame([]) has no "accession" column, so set_index raised
        # KeyError whenever every genome was already cached.
        if new_records:
            self.metadata_df = pd.concat(
                [self.metadata_df,
                 pd.DataFrame(new_records).set_index("accession")])
        self.metadata_df.to_csv(self.metadata_path)
        self.log.info("Completed metadata command")
Example #21
0
###########################################################################

# Parse the Newick data and render it in a circular layout.
t = Tree('{}'.format(treedata), format=1)
circular_style = TreeStyle()
circular_style.show_branch_length = True  # show branch length
circular_style.show_branch_support = True  # show support
circular_style.mode = "c"  # draw tree in circular mode
circular_style.scale = 100


t.render(path + "beautiful_life_tree.png", w=3000, units="mm", tree_style=circular_style)
t.render(path + "beautiful_life_tree.pdf", w=3000, units="mm", tree_style=circular_style)

# All leaf names of the tree.
species = list(t.get_leaf_names())

# Draw 42 leaf names uniformly at random, with replacement.
random_nodes = [str(random.choice(species)) for _ in range(42)]

# For extreme situations, use this to clean the names and prune the tree:
# clean_nodes = [re.sub(r'\n--', '', elem) for elem in random_nodes]
# new_t = t
# for item in clean_nodes:
#     print('Pruning: {}{}{}'.format("'", item, "'"))
#     new_t.prune('{}{}{}'.format("'", item, "'"))
     required=False,
     type=str,
     dest="nwk")
 parser.add_argument(
     '-t',
     '--tree',
     default=
     "../DataEmpirical/PrimatesBinaryLHTShort/rootedtree.nwk.annotated",
     required=False,
     type=str,
     dest="tree")
 args = parser.parse_args()
 tree = Tree(args.tree, format=1)
 nwk = Tree(args.nwk, format=1)
 assert (len(tree) == len(nwk))
 tree_leaf_names = set(tree.get_leaf_names())
 for leaf in nwk:
     if leaf.name not in tree_leaf_names:
         try:
             leaf.name = next(n for n in tree_leaf_names
                              if leaf.name.split("_")[0] == n.split("_")[0])
         except StopIteration:
             leaf.name = next(n for n in tree_leaf_names
                              if leaf.name.split("_")[1] == n.split("_")[1])
 assert (set(nwk.get_leaf_names()) == tree_leaf_names)
 root_age = nwk.get_closest_leaf()[1]
 df = [["Root", root_age, root_age, root_age]]
 for n in nwk.iter_descendants(strategy='postorder'):
     if not n.is_leaf():
         name = tree.get_common_ancestor(n.get_leaf_names()).name
         if n.dist <= 0:
Example #23
0
                print(e)
                logging.warning(
                    "File %s cannot be read as fasta alignment. Skipping the file"
                    % al)

    # now try to reset the leaves name in the tree
    common_name = []
    for leaf in tree:
        new_id = [x for x, y in spec_and_descr.items() if leaf.name in y]
        if new_id:
            common_name.append(new_id[0])
            leaf.add_feature("fullname", leaf.name)
            leaf.name = new_id[0]
    tree.prune(common_name)
    #print(tree.get_ascii(attributes=["name", "fullname"]))
    if len(tree) < len(spec_and_descr):
        logging.warning("The following genomes are missing in the tree:")

    for s in set.symmetric_difference(set(tree.get_leaf_names()),
                                      spec_and_descr.keys()):
        logging.warning(spec_and_descr[s])
    display_tree(tree,
                 al_per_gene,
                 al_len,
                 eddict,
                 out,
                 gcode,
                 colormap,
                 dpi=args.dpi,
                 width=args.width,
                 scale=args.scale)
Example #24
0
# Root the tree on the common ancestor of the two outgroups.
ancestor = t.get_common_ancestor(OUTG1, OUTG2)
t.set_outgroup(ancestor)

# Render the rooted tree (leaf names + branch support) and save it.
ts = TreeStyle()
ts.show_leaf_name = True
ts.show_branch_support = True
t.render(NEWICK + ".png", w=600, units="mm", tree_style=ts)
t.write(format=0, outfile=NEWICK + ".newick")


# Load the genome information table.
info = pd.read_table(g_genome, sep='\t', index_col=0)
frame = pd.DataFrame(info)


# Reorder the table rows to follow the leaf order of the tree.
strain_list = t.get_leaf_names()
SORT = frame.reindex(strain_list)
sorted_table_path = g_genome + '.sort.txt'
SORT.to_csv(sorted_table_path, sep='\t')


# Report the written files, then echo the tree and the sorted table.
rule = '--------------------------------------------------------------------------------------------------'
print("OUTPUT FILES:")
for output_file in (NEWICK + ".newick", NEWICK + ".png", sorted_table_path):
    print(output_file)

for shown in (t, SORT):
    print(rule)
    print(shown)
Example #25
0
File: ptpllh.py Project: zhangjiajie/PTP
class exponential_mixture:
    """ML search PTP, to use: __init__(), search() and count_species()

    Models the tree's branch lengths as a mixture of two exponential
    distributions (reported as "Coalescent" and "Speciation"). It searches
    for the set of speciation nodes with the highest log-likelihood.
    ``count_species()`` then compares that maximum with a single-rate null
    model using a likelihood-ratio test.
    """

    def __init__(self, tree, sp_rate=0, fix_sp_rate=False, max_iters=20000, min_br=0.0001):
        """Parse *tree* and initialise the search state.

        Args:
            tree: Passed to ``Tree(tree, format=1)``; polytomies are resolved
                and the root branch length is set to 0.
            sp_rate: Forwarded to every ``species_setting`` as ``sp_rate``.
            fix_sp_rate: Forwarded to every ``species_setting`` as
                ``fix_sp_rate``.
            max_iters: ``Brutal()`` falls back to ``H0()`` when the number
                of candidate settings exceeds this value.
            min_br: Branches not longer than this are ignored by the null
                model. It is also forwarded as ``minbr``.
        """
        self.min_brl = min_br
        self.tree = Tree(tree, format=1)
        self.tree.resolve_polytomy(recursive=True)
        self.tree.dist = 0.0
        self.fix_spe_rate = fix_sp_rate
        self.fix_spe = sp_rate
        self.max_logl = float("-inf")
        self.max_setting = None
        self.null_logl = 0.0
        self.null_model()
        self.species_list = None
        self.counter = 0
        self.setting_set = set()
        self.max_num_search = max_iters

    def null_model(self):
        """Fit one exponential to all branches longer than ``self.min_brl``.

        Stores ``sum_log_l()`` of the fit in ``self.null_logl`` and returns
        the fitted rate.
        """
        coa_br = []
        for node in self.tree.get_descendants():
            if node.dist > self.min_brl:
                coa_br.append(node.dist)
        e1 = exp_distribution(coa_br)
        self.null_logl = e1.sum_log_l()
        return e1.rate

    def __compare_node(self, node):
        """Sort key: the branch length above *node*."""
        return node.dist

    def _descendants_by_length(self):
        """Return all non-root nodes ordered by decreasing branch length.

        The list is sorted ascending (stable) and then reversed. Among equal
        lengths, the node later in traversal order therefore comes first.
        """
        node_list = self.tree.get_descendants()
        node_list.sort(key=self.__compare_node)
        node_list.reverse()
        return node_list

    def _new_setting(self, spe_nodes):
        """Build a ``species_setting`` rooted at ``self.tree`` with this model's rates."""
        return species_setting(spe_nodes=spe_nodes, root=self.tree, sp_rate=self.fix_spe,
                               fix_sp_rate=self.fix_spe_rate, minbr=self.min_brl)

    def _initial_setting(self):
        """Setting whose speciation nodes are the root and its children."""
        return self._new_setting([self.tree] + list(self.tree.get_children()))

    def _update_best(self, logl, setting):
        """Record *setting* as the overall best if *logl* beats ``self.max_logl``."""
        if logl > self.max_logl:
            self.max_logl = logl
            self.max_setting = setting

    def re_rooting(self):
        """Re-root on the node with the longest branch and zero the root length."""
        rootnode = self._descendants_by_length()[0]
        self.tree.set_outgroup(rootnode)
        self.tree.dist = 0.0

    def comp_num_comb(self):
        """Count candidate settings bottom-up and return the count at the root.

        A leaf counts 1.0. An internal node counts the product of its
        children's counts plus 1.0. Each count is stored as feature ``cnt``.
        """
        for node in self.tree.traverse(strategy='postorder'):
            if node.is_leaf():
                node.add_feature("cnt", 1.0)
            else:
                acum = 1.0
                for child in node.get_children():
                    acum = acum * child.cnt
                acum = acum + 1.0
                node.add_feature("cnt", acum)
        return self.tree.cnt

    def next(self, sp_setting):
        """Recursively explore every setting reachable from *sp_setting*.

        Each visited set of speciation nodes is recorded in
        ``self.setting_set``, and the best log-likelihood is kept. For each
        non-leaf active node, its children are added to the speciation
        nodes. The new setting is explored unless it was already visited.
        """
        self.setting_set.add(frozenset(sp_setting.spe_nodes))
        self._update_best(sp_setting.get_log_l(), sp_setting)
        for node in sp_setting.active_nodes:
            if node.is_leaf():
                continue
            sp_nodes = list(node.get_children()) + list(sp_setting.spe_nodes)
            new_sp_setting = species_setting(spe_nodes=sp_nodes, root=sp_setting.root,
                                             sp_rate=sp_setting.spe_rate,
                                             fix_sp_rate=sp_setting.fix_spe_rate,
                                             minbr=self.min_brl)
            if frozenset(sp_nodes) not in self.setting_set:
                self.next(new_sp_setting)

    def H0(self, reroot=True):
        """Run H1 (optionally re-rooting), then H2 and H3; keep the overall best."""
        self.H1(reroot)
        self.H2(reroot=False)
        self.H3(reroot=False)

    def H1(self, reroot=True):
        """Add speciation nodes in order of decreasing branch length.

        Each node not yet a speciation node is added with its siblings. If
        its parent is not a speciation node, the walk continues upward,
        adding siblings, until a speciation node or the root is reached.
        """
        if reroot:
            self.re_rooting()
        sorted_node_list = self._descendants_by_length()
        last_setting = self._initial_setting()
        max_logl = last_setting.get_log_l()
        max_setting = last_setting

        for node in sorted_node_list:
            if node in last_setting.spe_nodes:
                continue  # node already is a speciation node
            curr_sp_nodes = list(last_setting.spe_nodes)
            chosen_branching_node = node.up  # the father of this new node
            for nod in chosen_branching_node.get_children():
                if nod not in curr_sp_nodes:
                    curr_sp_nodes.append(nod)
            if chosen_branching_node not in last_setting.spe_nodes:
                while not chosen_branching_node.is_root():
                    chosen_branching_node = chosen_branching_node.up
                    for nod in chosen_branching_node.get_children():
                        if nod not in curr_sp_nodes:
                            curr_sp_nodes.append(nod)
                    if chosen_branching_node in last_setting.spe_nodes:
                        break
            new_setting = self._new_setting(curr_sp_nodes)
            new_logl = new_setting.get_log_l()
            if new_logl > max_logl:
                max_logl = new_logl
                max_setting = new_setting
            last_setting = new_setting

        self._update_best(max_logl, max_setting)

    def H2(self, reroot=True):
        """Greedy: repeatedly expand the active node giving the best log-likelihood."""
        if reroot:
            self.re_rooting()
        last_setting = self._initial_setting()
        max_logl = last_setting.get_log_l()
        max_setting = last_setting
        contin_flag = True

        while contin_flag:
            curr_max_logl = float("-inf")
            curr_max_setting = None
            contin_flag = False
            for node in last_setting.active_nodes:
                if node.is_leaf():
                    continue
                contin_flag = True
                sp_nodes = list(node.get_children()) + list(last_setting.spe_nodes)
                new_sp_setting = self._new_setting(sp_nodes)
                logl = new_sp_setting.get_log_l()
                if logl > curr_max_logl:
                    curr_max_logl = logl
                    curr_max_setting = new_sp_setting

            if curr_max_logl > max_logl:
                max_setting = curr_max_setting
                max_logl = curr_max_logl

            last_setting = curr_max_setting

        self._update_best(max_logl, max_setting)

    def H3(self, reroot=True):
        """Greedy search restricted to long-branch target nodes.

        First, the branch lengths (sorted in decreasing order) are split at
        the index where two separate exponential fits give the best combined
        log-likelihood. The nodes before the split are the targets. The
        greedy expansion then prefers active nodes whose children include a
        target, falling back to unrestricted expansion when none do.
        """
        if reroot:
            self.re_rooting()
        sorted_node_list = self._descendants_by_length()
        sorted_br = [node.dist for node in sorted_node_list]
        maxlogl = float("-inf")
        maxidx = -1
        for i in range(1, len(sorted_node_list)):
            e1 = exp_distribution(sorted_br[0:i])
            e2 = exp_distribution(sorted_br[i:])
            logl = e1.sum_log_l() + e2.sum_log_l()
            if logl > maxlogl:
                maxidx = i
                maxlogl = logl

        target_nodes = sorted_node_list[0:maxidx]

        last_setting = self._initial_setting()
        max_logl = last_setting.get_log_l()
        max_setting = last_setting
        contin_flag = True
        target_node_cnt = 0
        while contin_flag:
            curr_max_logl = float("-inf")
            curr_max_setting = None
            contin_flag = False
            unchanged_flag = True
            for node in last_setting.active_nodes:
                if node.is_leaf():
                    continue
                contin_flag = True
                childs = node.get_children()
                if any(child in target_nodes for child in childs):
                    unchanged_flag = False
                    sp_nodes = list(childs) + list(last_setting.spe_nodes)
                    new_sp_setting = self._new_setting(sp_nodes)
                    logl = new_sp_setting.get_log_l()
                    if logl > curr_max_logl:
                        curr_max_logl = logl
                        curr_max_setting = new_sp_setting
            if not unchanged_flag:
                target_node_cnt = target_node_cnt + 1
                if curr_max_logl > max_logl:
                    max_setting = curr_max_setting
                    max_logl = curr_max_logl
                last_setting = curr_max_setting

            if len(target_nodes) == target_node_cnt:
                contin_flag = False
            if contin_flag and unchanged_flag and last_setting is not None:
                # No active node touches a target: expand greedily instead.
                for node in last_setting.active_nodes:
                    if node.is_leaf():
                        continue
                    sp_nodes = list(node.get_children()) + list(last_setting.spe_nodes)
                    new_sp_setting = self._new_setting(sp_nodes)
                    logl = new_sp_setting.get_log_l()
                    if logl > curr_max_logl:
                        curr_max_logl = logl
                        curr_max_setting = new_sp_setting
                if curr_max_logl > max_logl:
                    max_setting = curr_max_setting
                    max_logl = curr_max_logl
                last_setting = curr_max_setting

        self._update_best(max_logl, max_setting)

    def Brutal(self, reroot=False):
        """Exhaustive search via ``next()``; falls back to H0 if too many settings."""
        if reroot:
            self.re_rooting()
        num_s = self.comp_num_comb()
        if num_s > self.max_num_search:
            print("Too many search iterations: " + repr(num_s) + ", using H0 instead!!!")
            self.H0(reroot=False)
        else:
            self.next(self._initial_setting())

    def search(self, strategy="H1", reroot=False):
        """Run the named strategy ("H1", "H2", "H3", "Brutal"); anything else runs H0."""
        if strategy == "H1":
            self.H1(reroot)
        elif strategy == "H2":
            self.H2(reroot)
        elif strategy == "H3":
            self.H3(reroot)
        elif strategy == "Brutal":
            self.Brutal(reroot)
        else:
            self.H0(reroot)

    def count_species(self, print_log=True, pv=0.001):
        """Return the number of species implied by the best setting.

        A likelihood-ratio test (1 degree of freedom) compares the best
        log-likelihood with the null model.

        If the p-value is below *pv*:
            * the best setting's species count is returned;
            * its species become ``self.species_list``.

        Otherwise:
            * all leaves form a single species;
            * 1 is returned.

        With *print_log*, rates, log-likelihoods, the p-value and
        Kolmogorov-Smirnov fit statistics are printed.
        """
        lhr = lh_ratio_test(self.null_logl, self.max_logl, 1)
        pvalue = lhr.get_p_value()
        if print_log:
            print("Speciation rate: " + "{0:.3f}".format(self.max_setting.rate2))
            print("Coalesecnt rate: " + "{0:.3f}".format(self.max_setting.rate1))
            print("Null logl: " + "{0:.3f}".format(self.null_logl))
            print("MAX logl: " + "{0:.3f}".format(self.max_logl))
            print("P-value: " + "{0:.3f}".format(pvalue))
            spefit, speaw = self.max_setting.e2.ks_statistic()
            coafit, coaaw = self.max_setting.e1.ks_statistic()
            print("Kolmogorov-Smirnov test for model fitting:")
            print("Speciation: " + "Dtest = {0:.3f}".format(spefit) + " " + speaw)
            print("Coalescent: " + "Dtest = {0:.3f}".format(coafit) + " " + coaaw)
        if pvalue < pv:
            num_sp, self.species_list = self.max_setting.count_species()
            return num_sp
        self.species_list = [self.tree.get_leaf_names()]
        return 1

    def whitening_search(self, strategy="H1", reroot=False, pv=0.001):
        """Search, prune the tree to the kept species, then search again.

        After the first search, the tree is pruned to the best setting's
        ``whiten_species()``. The null model and search state are reset
        before the second search.

        *pv* is accepted for backward compatibility but unused. It used to
        be forwarded to ``search()``, which takes no such argument, so every
        call raised ``TypeError``.
        """
        self.search(strategy, reroot)
        _, self.species_list = self.max_setting.count_species()
        spekeep = self.max_setting.whiten_species()
        self.tree.prune(spekeep)
        self.max_logl = float("-inf")
        self.max_setting = None
        self.null_logl = 0.0
        self.null_model()
        self.species_list = None
        self.counter = 0
        self.setting_set = set()
        self.search(strategy, reroot)

    def print_species(self):
        """Print each species (numbered from 1) followed by its leaf names."""
        for cnt, sp in enumerate(self.species_list, 1):
            print("Species " + repr(cnt) + ":")
            for leaf in sp:
                print("          " + leaf)

    def output_species(self, taxa_order=None):
        """taxa_order is a list of taxa names, the paritions will be output as the same order

        Returns ``(taxa_order, partition)``. ``partition[i]`` is the
        1-based species number of ``taxa_order[i]``. When *taxa_order* is
        empty or omitted, the tree's leaf order is used. If its length
        differs from the number of taxa in ``self.species_list``, an error
        is printed and ``(None, None)`` is returned. (The default is now
        ``None`` instead of a shared mutable list; passing ``[]`` still
        works.)
        """
        if taxa_order is None or len(taxa_order) == 0:
            taxa_order = self.tree.get_leaf_names()

        num_taxa = sum(1 for sp in self.species_list for _ in sp)
        if len(taxa_order) != num_taxa:
            print("error error, taxa_order != num_taxa!")
            return None, None
        partion = [-1] * num_taxa
        for cnt, sp in enumerate(self.species_list, 1):
            for leaf in sp:
                partion[taxa_order.index(leaf)] = cnt
        return taxa_order, partion
Example #26
0
class exponential_mixture:
    """ML search PTP, to use: __init__(), search() and count_species()

    Searches over ``species_setting`` objects built on a rooted tree for the
    one with the highest log-likelihood (``get_log_l``).  ``count_species``
    reports the best setting's ``rate2`` as the speciation rate and ``rate1``
    as the coalescent rate.  It compares the best setting against a
    single-exponential null model (``null_model``) with a likelihood-ratio
    test.
    """

    def __init__(
        self,
        tree,
        sp_rate=0,
        fix_sp_rate=False,
        max_iters=20000,
        min_br=0.0001,
    ):
        """Load the tree and initialise the search state.

        Args:
            tree: anything accepted by ``Tree(tree, format=1)``.
            sp_rate: forwarded to every ``species_setting`` as ``sp_rate``.
            fix_sp_rate: forwarded to every ``species_setting`` as
                ``fix_sp_rate``.
            max_iters: ``Brutal`` falls back to ``H0`` when the combination
                count from ``comp_num_comb`` exceeds this.
            min_br: forwarded as ``minbr``; the null model ignores branches
                that are not longer than this.
        """
        self.min_brl = min_br
        self.tree = Tree(tree, format=1)
        # make every internal node bifurcating and zero the root branch
        self.tree.resolve_polytomy(recursive=True)
        self.tree.dist = 0.0
        self.fix_spe_rate = fix_sp_rate
        self.fix_spe = sp_rate
        self._reset_search_state()
        self.max_num_search = max_iters

    def _reset_search_state(self):
        """Forget any previous result and recompute the null model."""
        self.max_logl = float("-inf")
        self.max_setting = None
        self.null_logl = 0.0
        self.null_model()
        self.species_list = None
        self.counter = 0
        self.setting_set = set()

    def null_model(self):
        """Fit a single exponential to all branch lengths above ``min_brl``.

        Stores the fit's log-likelihood in ``self.null_logl`` and returns the
        fitted rate.
        """
        coa_br = [node.dist for node in self.tree.get_descendants()
                  if node.dist > self.min_brl]
        e1 = exp_distribution(coa_br)
        self.null_logl = e1.sum_log_l()
        return e1.rate

    def __compare_node(self, node):
        return node.dist

    def _nodes_by_dist_desc(self):
        """Return all non-root nodes ordered by branch length, longest first.

        Ascending sort followed by ``reverse()`` (not ``reverse=True``) is kept
        on purpose: it fixes the relative order of equal-length branches,
        which decides which node the heuristics visit first.
        """
        nodes = self.tree.get_descendants()
        nodes.sort(key=self.__compare_node)
        nodes.reverse()
        return nodes

    def _initial_setting(self):
        """Setting in which the root and its children are speciation nodes."""
        first_node_list = [self.tree] + list(self.tree.get_children())
        return species_setting(
            spe_nodes=first_node_list,
            root=self.tree,
            sp_rate=self.fix_spe,
            fix_sp_rate=self.fix_spe_rate,
            minbr=self.min_brl,
        )

    def _split_setting(self, setting, node):
        """New setting: ``node``'s children plus ``setting``'s speciation nodes."""
        sp_nodes = list(node.get_children()) + list(setting.spe_nodes)
        return species_setting(
            spe_nodes=sp_nodes,
            root=self.tree,
            sp_rate=self.fix_spe,
            fix_sp_rate=self.fix_spe_rate,
            minbr=self.min_brl,
        )

    def _update_best(self, logl, setting):
        """Record ``setting`` as the overall best if ``logl`` beats it."""
        if logl > self.max_logl:
            self.max_logl = logl
            self.max_setting = setting

    @staticmethod
    def _append_new(target, nodes):
        """Append each of ``nodes`` to ``target`` unless already present."""
        for nod in nodes:
            if nod not in target:
                target.append(nod)

    def re_rooting(self):
        """Root the tree on its longest branch and zero the root branch length."""
        rootnode = self._nodes_by_dist_desc()[0]
        self.tree.set_outgroup(rootnode)
        self.tree.dist = 0.0

    def comp_num_comb(self):
        """Compute the combination count ``Brutal`` uses to bound its search.

        In postorder, each leaf gets the feature ``cnt = 1.0`` and each
        internal node ``cnt = product(child.cnt) + 1.0``.  The root's ``cnt``
        is returned.
        """
        for node in self.tree.traverse(strategy="postorder"):
            if node.is_leaf():
                node.add_feature("cnt", 1.0)
            else:
                acum = 1.0
                for child in node.get_children():
                    acum *= child.cnt
                node.add_feature("cnt", acum + 1.0)
        return self.tree.cnt

    def next(self, sp_setting):
        """Exhaustively explore settings reachable from ``sp_setting``.

        This is a depth-first recursion.  Every active non-leaf node is split
        (its children become speciation nodes).  Sets of speciation nodes
        already seen are recorded in ``self.setting_set`` and not revisited.
        The best setting found updates ``self.max_logl``/``self.max_setting``.
        """
        self.setting_set.add(frozenset(sp_setting.spe_nodes))
        self._update_best(sp_setting.get_log_l(), sp_setting)
        for node in sp_setting.active_nodes:
            if node.is_leaf():
                continue
            sp_nodes = list(node.get_children()) + list(sp_setting.spe_nodes)
            new_sp_setting = species_setting(
                spe_nodes=sp_nodes,
                root=sp_setting.root,
                sp_rate=sp_setting.spe_rate,
                fix_sp_rate=sp_setting.fix_spe_rate,
                minbr=self.min_brl,
            )
            if frozenset(sp_nodes) not in self.setting_set:
                self.next(new_sp_setting)

    def H0(self, reroot=True):
        """Run H1 (optionally rerooting), then H2 and H3; the best result wins."""
        self.H1(reroot)
        self.H2(reroot=False)
        self.run_h3(reroot=False)

    def H1(self, reroot=True):
        """Heuristic 1: add speciation nodes by decreasing branch length.

        The search starts from the root split and visits nodes from the
        longest branch to the shortest.  For each node that is not yet a
        speciation node, its parent's children are added.  If the parent is
        not a speciation node either, the children of every further ancestor
        are added too.  This walk goes up until it reaches an ancestor that
        already is a speciation node, or the root.  Each step yields a new
        setting, and the best one is kept.
        """
        if reroot:
            self.re_rooting()

        sorted_node_list = self._nodes_by_dist_desc()
        last_setting = self._initial_setting()
        max_logl = last_setting.get_log_l()
        max_setting = last_setting

        for node in sorted_node_list:
            if node in last_setting.spe_nodes:
                continue  # node already is a speciation node, do nothing
            curr_sp_nodes = list(last_setting.spe_nodes)
            branching = node.up  # the father of this new node
            self._append_new(curr_sp_nodes, branching.get_children())
            if branching not in last_setting.spe_nodes:
                while not branching.is_root():
                    branching = branching.up
                    self._append_new(curr_sp_nodes, branching.get_children())
                    if branching in last_setting.spe_nodes:
                        break
            new_setting = species_setting(
                spe_nodes=curr_sp_nodes,
                root=self.tree,
                sp_rate=self.fix_spe,
                fix_sp_rate=self.fix_spe_rate,
                minbr=self.min_brl,
            )
            new_logl = new_setting.get_log_l()
            if new_logl > max_logl:
                max_logl = new_logl
                max_setting = new_setting
            last_setting = new_setting

        self._update_best(max_logl, max_setting)

    def H2(self, reroot=True):
        """Greedy: repeatedly apply the single best split of an active node.

        The search stops when no active internal node is left.  The best
        setting seen along the way is kept.
        """
        if reroot:
            self.re_rooting()

        last_setting = self._initial_setting()
        max_logl = last_setting.get_log_l()
        max_setting = last_setting
        contin_flag = True

        while contin_flag:
            curr_max_logl = float("-inf")
            curr_max_setting = None
            contin_flag = False
            for node in last_setting.active_nodes:
                if node.is_leaf():
                    continue
                contin_flag = True
                new_sp_setting = self._split_setting(last_setting, node)
                logl = new_sp_setting.get_log_l()
                if logl > curr_max_logl:
                    curr_max_logl = logl
                    curr_max_setting = new_sp_setting

            if curr_max_logl > max_logl:
                max_setting = curr_max_setting
                max_logl = curr_max_logl

            if curr_max_setting is None:
                # nothing could be split (or no finite log-likelihood): stop
                # rather than dereference None on the next iteration
                break
            last_setting = curr_max_setting

        self._update_best(max_logl, max_setting)

    def run_h3(self, reroot=True):
        """Heuristic 3: greedy search guided by a branch-length threshold.

        Branch lengths, sorted longest first, are cut at the index that
        maximises the summed log-likelihood of two exponential fits.  The
        nodes before the cut become "target" nodes.  Each round applies the
        best split among active nodes that have a target child.  If no active
        node has one, the round applies the best unrestricted split instead.
        The search stops when no active internal node remains, or after as
        many guided rounds as there are target nodes.  The best setting seen
        is kept.
        """
        if reroot:
            self.re_rooting()
        sorted_node_list = self._nodes_by_dist_desc()
        sorted_br = [node.dist for node in sorted_node_list]
        maxlogl = float("-inf")
        maxidx = -1
        for i in range(1, len(sorted_node_list)):
            e1 = exp_distribution(sorted_br[:i])
            e2 = exp_distribution(sorted_br[i:])
            logl = e1.sum_log_l() + e2.sum_log_l()
            if logl > maxlogl:
                maxidx = i
                maxlogl = logl

        target_nodes = sorted_node_list[0:maxidx]

        last_setting = self._initial_setting()
        max_logl = last_setting.get_log_l()
        max_setting = last_setting
        contin_flag = True
        target_node_cnt = 0
        while contin_flag:
            curr_max_logl = float("-inf")
            curr_max_setting = None
            contin_flag = False
            unchanged_flag = True
            # guided round: only split nodes with at least one target child
            for node in last_setting.active_nodes:
                if node.is_leaf():
                    continue
                contin_flag = True
                if any(child in target_nodes for child in node.get_children()):
                    unchanged_flag = False
                    new_sp_setting = self._split_setting(last_setting, node)
                    logl = new_sp_setting.get_log_l()
                    if logl > curr_max_logl:
                        curr_max_logl = logl
                        curr_max_setting = new_sp_setting
            if not unchanged_flag:
                target_node_cnt += 1
                if curr_max_logl > max_logl:
                    max_setting = curr_max_setting
                    max_logl = curr_max_logl
                last_setting = curr_max_setting

            if len(target_nodes) == target_node_cnt:
                contin_flag = False
            # fallback round: no active node touches a target
            if contin_flag and unchanged_flag and last_setting is not None:
                for node in last_setting.active_nodes:
                    if node.is_leaf():
                        continue
                    new_sp_setting = self._split_setting(last_setting, node)
                    logl = new_sp_setting.get_log_l()
                    if logl > curr_max_logl:
                        curr_max_logl = logl
                        curr_max_setting = new_sp_setting
                if curr_max_logl > max_logl:
                    max_setting = curr_max_setting
                    max_logl = curr_max_logl
                last_setting = curr_max_setting

        self._update_best(max_logl, max_setting)

    def Brutal(self, reroot=False):
        """Exhaustive search via ``next``.

        Falls back to ``H0`` when ``comp_num_comb`` exceeds
        ``max_num_search``.
        """
        if reroot:
            self.re_rooting()
        num_s = self.comp_num_comb()
        if num_s > self.max_num_search:
            print("Too many search iterations: " + repr(num_s) +
                  ", using H0 instead!!!")
            self.H0(reroot=False)
        else:
            self.next(self._initial_setting())

    def search(self, strategy="H1", reroot=False):
        """Run the heuristic named by ``strategy``.

        "H1", "H2", "H3" and "Brutal" select the matching method.  Any other
        value runs ``H0`` (H1, H2 and H3 in turn).  ``reroot`` is passed
        through.
        """
        if strategy == "H1":
            self.H1(reroot)
        elif strategy == "H2":
            self.H2(reroot)
        elif strategy == "H3":
            self.run_h3(reroot)
        elif strategy == "Brutal":
            self.Brutal(reroot)
        else:
            self.H0(reroot)

    def count_species(self, print_log=True, pv=0.001):
        """Return the number of delimited species and set ``species_list``.

        A likelihood-ratio test (1 degree of freedom) compares the best
        setting with the null model.  If its p-value is below ``pv``, species
        come from the best setting; otherwise all leaves form a single
        species.  Requires a prior ``search`` (uses ``self.max_setting``).
        When ``print_log`` is true, rates, log-likelihoods, the p-value and
        Kolmogorov-Smirnov fit statistics are printed.
        """
        lhr = lh_ratio_test(self.null_logl, self.max_logl, 1)
        pvalue = lhr.get_p_value()
        if print_log:
            print("Speciation rate: " +
                  "{0:.3f}".format(self.max_setting.rate2))
            print("Coalesecnt rate: " +
                  "{0:.3f}".format(self.max_setting.rate1))
            print("Null logl: " + "{0:.3f}".format(self.null_logl))
            print("MAX logl: " + "{0:.3f}".format(self.max_logl))
            print("P-value: " + "{0:.3f}".format(pvalue))
            spefit, speaw = self.max_setting.e2.ks_statistic()
            coafit, coaaw = self.max_setting.e1.ks_statistic()
            print("Kolmogorov-Smirnov test for model fitting:")
            print("Speciation: " + "Dtest = {0:.3f}".format(spefit) + " " +
                  speaw)
            print("Coalescent: " + "Dtest = {0:.3f}".format(coafit) + " " +
                  coaaw)
        if pvalue < pv:
            num_sp, self.species_list = self.max_setting.count_species()
            return num_sp
        self.species_list = [self.tree.get_leaf_names()]
        return 1

    def whitening_search(self, strategy="H1", reroot=False, pv=0.001):
        """Search, prune the tree to the whitened species, and search again.

        ``pv`` is kept for interface compatibility but is not used here:
        ``search`` takes no p-value (pass it to ``count_species`` instead).
        Forwarding it to ``search`` used to raise ``TypeError``.
        """
        self.search(strategy, reroot)
        _, self.species_list = self.max_setting.count_species()
        spekeep = self.max_setting.whiten_species()
        self.tree.prune(spekeep)
        # start over on the pruned tree
        self._reset_search_state()
        self.search(strategy, reroot)

    def print_species(self):
        """Print each species (numbered from 1) followed by its leaves."""
        for cnt, sp in enumerate(self.species_list, start=1):
            print("Species " + repr(cnt) + ":")
            for leaf in sp:
                print("          " + leaf)

    def output_species(self, taxa_order=None):
        """Return ``(taxa_order, partition)`` for the current ``species_list``.

        The partitions are output in the same order as ``taxa_order``:
        ``partition[i]`` is the 1-based species index of ``taxa_order[i]``.
        When ``taxa_order`` is omitted, ``None`` or empty, the tree's leaf
        names are used.  On a taxa-count mismatch an error is printed and
        ``(None, None)`` is returned.
        """
        # None instead of a mutable [] default; an explicit [] still works.
        if taxa_order is None or len(taxa_order) == 0:
            taxa_order = self.tree.get_leaf_names()

        num_taxa = sum(1 for sp in self.species_list for _ in sp)
        if len(taxa_order) != num_taxa:
            print("error error, taxa_order != num_taxa!")
            return None, None
        partion = [-1] * num_taxa
        for cnt, sp in enumerate(self.species_list, start=1):
            for leaf in sp:
                partion[taxa_order.index(leaf)] = cnt
        return taxa_order, partion
示例#27
0
def cladecluster_bysimilarity(treefpath, pwiddf_fpath, outgroup,
                              cluster_minpwid, cluster_minsize):
    """Collapse similar clades of a tree into cluster leaves and save it.

    The tree is rooted on ``outgroup`` and walked from the root.  A clade is
    collapsed into a single leaf carrying its leaf names in an ``accs``
    feature when ``clusterstats`` reports both:

    - more than ``cluster_minsize`` leaves, and
    - an average pairwise identity above ``cluster_minpwid``.

    Collapsed clades are not descended into.  Each remaining uncollapsed
    leaf ("orphan") is merged into the cluster with the highest mean
    pairwise identity to it.  An orphan with no cluster of positive identity
    is left in place.  Cluster leaves are renamed
    ``cluster|acc1|acc2|...``, and the tree is written to ``newtree.nw`` in
    the directory of ``treefpath``.

    Args:
        treefpath: path to the input tree.
        pwiddf_fpath: pickled pandas DataFrame of pairwise identities; it is
            indexed as ``df.loc[acc, [accs...]]`` (presumably accession on
            both axes -- confirm).
        outgroup: name of the leaf used to root the tree.
        cluster_minpwid: minimum average pairwise identity of a cluster.
        cluster_minsize: a clade must have more than this many leaves.
    """
    print(treefpath)
    t = Tree(treefpath)
    t.set_outgroup(t & outgroup)
    print(len(t.get_leaf_names()))
    pwid_df = pd.read_pickle(pwiddf_fpath)

    tovisit_ = [t]
    while tovisit_:
        node = tovisit_.pop()
        numleaves, tlength, avgpwid = clusterstats(node, pwid_df)
        print(numleaves, tlength, avgpwid)
        if numleaves > cluster_minsize and avgpwid > cluster_minpwid:
            print('building a new leaf group')
            groupaccs_ = node.get_leaf_names()
            print('group size= {}'.format(len(groupaccs_)))
            node.add_feature("accs", groupaccs_)
            # iterate over a copy: remove_child mutates node.children
            for child in node.children[:]:
                node.remove_child(child)
        else:
            tovisit_.extend(node.children)

    clusternodes_ = []
    orphanleaves_ = []
    for node in t.traverse():
        if not node.is_leaf():
            continue
        try:
            node.accs
        except AttributeError:
            orphanleaves_.append(node)
        else:
            clusternodes_.append(node)

    for orphanleaf in orphanleaves_:
        orphacc = orphanleaf.name
        bestcluster = None
        bestcluster_pwid = 0.0
        for clusternode in clusternodes_:
            avgpwid = pwid_df.loc[orphacc, clusternode.accs].mean()
            if avgpwid > bestcluster_pwid:
                bestcluster_pwid = avgpwid
                bestcluster = clusternode
        if bestcluster is None:
            # No cluster exists or none has positive identity: keep the leaf
            # as-is instead of crashing on None.accs.
            print('no cluster found for {}'.format(orphacc))
            continue
        bestcluster.accs.append(orphacc)
        orphanleaf.delete()

    for node in t.traverse():
        try:
            accs = node.accs
        except AttributeError:
            continue
        print(len(accs))
        if not accs:
            continue
        node.add_feature('name', 'cluster|' + '|'.join(accs))

    newtreefpath = os.path.join(os.path.split(treefpath)[0], "newtree.nw")
    print(newtreefpath)
    t.write(outfile=newtreefpath, features=['name'])
示例#28
0
def generax2mcmctree(xml_file,
                     stree,
                     gene,
                     dating_o,
                     calbration_file,
                     genome2cog25=None,
                     must_in_list="/mnt/home-backup/thliao/cyano_basal/rawdata/assembly_ids.list"):
    """
    generate files for mcmctree, including
    1. list of used genomes
    2. used genomes (itol annotation)
    3. constructed species tree topology for dating
    4. target genomes in the complete phylogeny of phylum (itol annotation)
    5. calibration file and tree file with calibrations information

    Args:
        xml_file: GeneRax reconciliation XML.  The phylum name is taken from
            its 4th-from-last path component.
        stree: species tree path (read with ``format=3``).  The cluster list
            is read from the same path with ``.reroot.newick`` replaced by
            ``.clusterd.list``.
        gene: gene name, used in the output file names.
        dating_o: output directory.  It must already contain ``id_list/``,
            ``species_trees/``, ``target_nodes/`` and ``calibrations/``.
        calbration_file: calibration template.  Every ``GCA_002239005.1`` in
            it is replaced by the first leaf of the root child that does not
            contain ``GCA_000011385.1``.
        genome2cog25: optional mapping forwarded to ``sampling``; defaults to
            a fresh empty dict per call.
        must_in_list: file of genome ids (one per line) that must be
            retained; defaults to the previously hard-coded path.
    """
    if genome2cog25 is None:
        genome2cog25 = {}
    # xml_file = join(r_odir,f'reconciliations/{gene}_reconciliated.xml')
    # stree = f"./trees/iqtree/{phylum_name}.reroot.newick"

    phylum_name = xml_file.split('/')[-4]
    st = Tree(stree, format=3)
    tmp_name = phylum_name + '_' + gene
    _p2node, _p2node_transfer_receptor = get_p2node(xml_file,
                                                    stree,
                                                    key=tmp_name)
    target_nodes = list(_p2node.values())[0] + list(
        _p2node_transfer_receptor.values())[0]

    with open(must_in_list) as fh:
        must_in_genomes = fh.read().strip('\n').split('\n')
    # new calibrations are /mnt/home-backup/thliao/cyano/ref_genomes_list.txt
    cluster2genomes = get_cluster(
        stree.replace('.reroot.newick', '.clusterd.list'))
    g2cluster = {v: c for c, d in cluster2genomes.items() for v in d}
    retained_ids = sampling(st,
                            target_nodes,
                            must_in=must_in_genomes,
                            node2cluster=g2cluster,
                            genome2cog25=genome2cog25)

    # iTOL colour-range annotation marking the retained genomes
    text = to_color_range({g: 'keep' for g in retained_ids},
                          {"keep": "#88b719"})
    with open(join(dating_o, f'id_list/{phylum_name}_{gene}.txt'), 'w') as f1:
        f1.write(text)
    with open(join(dating_o, f'id_list/{phylum_name}_{gene}.list'), 'w') as f1:
        f1.write('\n'.join(retained_ids))
    print(phylum_name, len(st.get_leaf_names()), len(retained_ids))
    # prune in place: the target-node and calibration steps use the pruned tree
    st.prune(retained_ids)
    with open(join(dating_o, f'species_trees/{phylum_name}_{gene}.newick'),
              'w') as f1:
        f1.write(st.write(format=9))

    # draw target nodes, each identified by one leaf from each child clade
    LCA_nodes = []
    for name in target_nodes:
        matches = [x for x in st.traverse() if x.name == name]
        node = matches[0]
        l1 = node.children[0].get_leaf_names()[0]
        l2 = node.children[1].get_leaf_names()[0]
        LCA_nodes.append(f"{l1}|{l2}")

    text = pie_chart({n: {'speciation': 1} for n in LCA_nodes},
                     {"speciation": "#ff0000"},
                     dataset_label='GeneRax results')
    with open(join(dating_o, f'target_nodes/{phylum_name}_{gene}.txt'),
              'w') as f1:
        f1.write(text)

    # new set file set14
    # set14_f = './dating/calibration_sets/scheme1/cal_set14.txt'

    c = 'GCA_000011385.1'
    sister = [_ for _ in st.children if c not in _.get_leaf_names()][0]
    with open(calbration_file) as fh:
        final_text = fh.read().replace('GCA_002239005.1',
                                       sister.get_leaf_names()[0])
    with open(join(dating_o, f'calibrations/{phylum_name}_{gene}_set14.txt'),
              'w') as f1:
        f1.write(final_text)
def parse_pda(handle):
    """Return the first line longer than 100 characters, stripped.

    Lines are read with ``handle.readlines()``.  Returns ``None`` when no
    line is that long.
    """
    long_lines = (line.strip() for line in handle.readlines() if len(line) > 100)
    return next(long_lines, None)


if __name__ == "__main__":
    import subprocess

    args = parser_code()
    input_tree = args.input_tree
    output_tree = args.output_tree
    accession_file = args.accession
    # Side files live next to the input tree.  os.path.dirname gives "" for a
    # bare file name, so they go to the current directory instead of "/".
    input_dir = os.path.dirname(input_tree)
    difference = os.path.join(input_dir, "not_included.txt")
    filter_file = os.path.join(input_dir, "filter.txt")
    t = Tree(input_tree, format=1)
    leaves_full = t.get_leaf_names()
    size = int(args.tree_size)
    keep = args.keeper
    # Run PDA without a shell so paths are passed verbatim (no injection).
    subprocess.run(["./pda", "-k", str(size), input_tree, output_tree,
                    "-if", str(keep)])
    with open(output_tree, "r") as handle:
        tree = parse_pda(handle)
    if tree is None:
        raise SystemExit(
            "No tree line found in PDA output {}".format(output_tree))
    t = Tree(tree, format=1)
    leaves_partial = t.get_leaf_names()
    print(leaves_partial)
    partial_set = set(leaves_partial)
    not_included = [item for item in leaves_full if item not in partial_set]
    t.set_outgroup("KE136308.1")
    t.write(outfile=output_tree, format=1)
    with open(difference, 'w') as outfile:
        for item in not_included:
            outfile.write(item + "\n")
示例#30
0
def cli(gnumber,
        glist,
        gtree,
        edprob,
        gsize,
        glen_range,
        dnds,
        tau=None,
        delrate=0.0,
        from_al=None,
        protlike=False,
        no_syn=False,
        sub_rate=1.0,
        min_cons=0.0,
        outdir=""):
    """Extract genome content based on a list of species

    The genome tree comes from exactly one source, checked in this order:

    - ``gnumber``: names ``Genome_1`` .. ``Genome_N`` on a tree built with
      ``populate`` using random branch lengths.
    - ``glist``: a file with one name per line; ``#`` lines are comments.
      Names starting with ``-`` or ``_`` are excluded from editing.
    - ``gtree``: a tree file.  Leaves starting with ``_`` are excluded from
      editing.

    Then, for each of ``gsize`` genes, codon sequences are simulated with
    pyvolve.  Random deletions and an ``ATG`` prefix are optionally applied.
    C-to-U editing is simulated (``CtoUsimulate``), and everything is saved
    with ``save_data``.

    Args:
        gnumber, glist, gtree: genome source (see above).
        edprob: editing probability forwarded to ``CtoUsimulate``.
        gsize: number of genes to simulate.
        glen_range: (low, high) gene length in codons for
            ``np.random.randint`` (high exclusive).
        dnds: pair mapped to the pyvolve codon parameters as
            ``beta=dnds[0]``, ``alpha=dnds[1]``.
        tau: optional ``kappa`` parameter.
        delrate: deletion rate; ``random_deletion`` is applied when non-zero.
        from_al: alignment whose codon frequencies become ``state_freqs``.
        protlike: prefix every sequence with ``ATG``.
        no_syn, min_cons: forwarded to ``CtoUsimulate``.
        sub_rate: ``scale_tree`` for ``pyvolve.read_tree``.
        outdir: output directory for the simulation and ``save_data``.

    Raises:
        NotImplementedError: if none of gnumber, glist, gtree is given.
    """
    gleaf = []
    no_edit = []
    tree = None
    if gnumber:
        gleaf = ['Genome_{}'.format(i) for i in range(1, gnumber + 1)]
    elif glist:
        with open(glist) as G:
            for line in G:  # was `Glist`, an undefined name (NameError)
                line = line.strip()
                if line and not line.startswith('#'):
                    gleaf.append(line.strip('-_'))
                    if line.startswith(('-', '_')):
                        no_edit.append(line.strip('-_'))
    elif gtree:
        tree = Tree(gtree)
        gleaf = tree.get_leaf_names()
        no_edit = [x.strip('_') for x in gleaf if x.startswith('_')]
        for node in tree:
            node.name = node.name.strip('_')

    else:
        raise NotImplementedError(
            "One of --gnumber, --glist and --gtree is needed !")

    if tree is None:
        tree = Tree()
        tree.populate(len(gleaf), names_library=gleaf, random_branches=True)

    param_list = {"alpha": dnds[1], "beta": dnds[0]}
    if tau:
        param_list.update({"kappa": tau})

    if from_al:  # read codons frequencies from an existing alignment
        f = pyvolve.ReadFrequencies("codon", file=from_al)
        param_list.update({'state_freqs': f.compute_frequencies()})

    phylogeny = pyvolve.read_tree(tree=tree.write(format=5),
                                  scale_tree=sub_rate)
    codon_model = pyvolve.Model("codon", param_list)
    # add height to tree
    tree = add_height_to_tree(tree)

    for i in range(gsize):
        # gene length (in codons) drawn uniformly, converted to nucleotides
        alen = np.random.randint(glen_range[0], glen_range[1]) * 3
        seq = simulate_genomes(codon_model, phylogeny, alen, outdir, i + 1)
        if delrate:
            seq = random_deletion(seq, tree, alen // 3, delrate)
        if protlike:
            for k in seq:
                seq[k] = 'ATG' + seq[k]
        edited_seq, truth_table = CtoUsimulate(seq,
                                               tree,
                                               no_edit,
                                               edprob,
                                               no_syn=no_syn,
                                               min_cons=min_cons)
        save_data(tree, seq, edited_seq, truth_table, outdir, i + 1)
示例#31
0
def draw_tree(ax,
              tx,
              rmargin=.3,
              treecolor="k",
              leafcolor="k",
              supportcolor="k",
              outgroup=None,
              reroot=True,
              gffdir=None,
              sizes=None,
              trunc_name=None,
              SH=None,
              scutoff=0,
              barcodefile=None,
              leafcolorfile=None,
              leaffont=12):
    """
    main function for drawing phylogenetic tree

    Draws ``tx`` (anything ``Tree`` accepts) on ``ax`` in normalised 0-1
    coordinates:

    - leaves are stacked top to bottom;
    - internal nodes sit at the mean height of their children;
    - a 0.1-unit scale bar is drawn below the tree.

    Args:
        ax: axes object; ``plot`` and ``text`` are called on it.
        rmargin: fraction of the width reserved on the right for exon
            glyphs and size labels.
        treecolor, leafcolor, supportcolor: colours of branches, leaf labels
            and support values.
        outgroup: leaf names whose common ancestor becomes the outgroup;
            without it, midpoint rooting is used (only if ``reroot``).
        reroot: whether to reroot the tree at all.
        gffdir: directory of ``*.gff*`` files drawn as exon structures next
            to leaves whose name prefix (before ``_``) matches.
        sizes: sizes file; matching leaves get a base-pair length converted
            to amino acids.
        trunc_name: rule passed to ``truncate_name`` for leaf labels.
        SH: when given, printed as the SH-test result under the tree.
        scutoff: support (percent) must exceed this to be drawn.
        barcodefile: tab-delimited file used by ``decode_name``.
        leafcolorfile: tab-delimited leaf -> colour file; a colour may be a
            name or an ``"R,G,B"`` float triple.
        leaffont: font size of all text.
    """

    t = Tree(tx)
    if reroot:
        if outgroup:
            R = t.get_common_ancestor(*outgroup)
        else:
            # Calculate the midpoint node
            R = t.get_midpoint_outgroup()

        if R != t:
            t.set_outgroup(R)

    _, max_dist = t.get_farthest_leaf()

    margin = .05
    xstart = margin
    ystart = 1 - margin
    canvas = 1 - rmargin - 2 * margin
    tip = .005
    # scale the tree
    scale = canvas / max_dist

    num_leaves = len(t.get_leaf_names())
    yinterval = canvas / (num_leaves + 1)

    # get exons structures, if any
    structures = {}
    if gffdir:
        gffiles = glob("{0}/*.gff*".format(gffdir))
        setups, ratio = get_setups(gffiles, canvas=rmargin / 2, noUTR=True)
        structures = {a: (b, c) for a, b, c in setups}

    if sizes:
        sizes = Sizes(sizes).mapping

    if barcodefile:
        barcodemap = DictFile(barcodefile, delimiter="\t")

    # leaf name -> colour overrides; empty when no file is given
    leafcolors = {}
    if leafcolorfile:
        leafcolors = DictFile(leafcolorfile, delimiter="\t")

    coords = {}
    i = 0
    for n in t.traverse("postorder"):
        dist = n.get_distance(t)
        xx = xstart + scale * dist

        if n.is_leaf():
            yy = ystart - i * yinterval
            i += 1

            if trunc_name:
                name = truncate_name(n.name, rule=trunc_name)
            else:
                name = n.name

            if barcodefile:
                name = decode_name(name, barcodemap)

            sname = name.replace("_", "-")

            lc = leafcolor
            if n.name in leafcolors:
                lc = leafcolors[n.name]
                # if color is given as "R,G,B"
                if "," in lc:
                    # a tuple, not a lazy map object, so matplotlib accepts it
                    lc = tuple(float(x) for x in lc.split(","))

            ax.text(xx + tip,
                    yy,
                    sname,
                    va="center",
                    fontstyle="italic",
                    size=leaffont,
                    color=lc)

            gname = n.name.split("_")[0]
            if gname in structures:
                mrnabed, cdsbeds = structures[gname]
                ExonGlyph(ax,
                          1 - rmargin / 2,
                          yy,
                          mrnabed,
                          cdsbeds,
                          align="right",
                          ratio=ratio)
            if sizes and gname in sizes:
                # base pair converted to amino acid (integer, no ".0")
                size = sizes[gname] // 3 - 1
                size = "{0}aa".format(size)
                ax.text(1 - rmargin / 2 + tip, yy, size, size=leaffont)

        else:
            children = [coords[x] for x in n.get_children()]
            _, children_y = zip(*children)
            min_y, max_y = min(children_y), max(children_y)
            # plot the vertical bar
            ax.plot((xx, xx), (min_y, max_y), "-", color=treecolor)
            # plot the horizontal bar
            for cx, cy in children:
                ax.plot((xx, cx), (cy, cy), "-", color=treecolor)
            yy = sum(children_y) * 1. / len(children_y)
            support = n.support
            if support > 1:
                support = support / 100.
            if not n.is_root():
                if support > scutoff / 100.:
                    ax.text(xx,
                            yy + .005,
                            "{0:d}".format(int(abs(support * 100))),
                            ha="right",
                            size=leaffont,
                            color=supportcolor)

        coords[n] = (xx, yy)

    # scale bar
    br = .1
    x1 = xstart + .1
    x2 = x1 + br * scale
    yy = ystart - i * yinterval
    ax.plot([x1, x1], [yy - tip, yy + tip], "-", color=treecolor)
    ax.plot([x2, x2], [yy - tip, yy + tip], "-", color=treecolor)
    ax.plot([x1, x2], [yy, yy], "-", color=treecolor)
    ax.text((x1 + x2) / 2,
            yy - tip,
            "{0:g}".format(br),
            va="top",
            ha="center",
            size=leaffont,
            color=treecolor)

    if SH is not None:
        xs = x1
        ys = (margin + yy) / 2.
        ax.text(xs,
                ys,
                "SH test against ref tree: {0}".format(SH),
                ha="left",
                size=leaffont,
                color="g")

    normalize_axes(ax)
示例#32
0
        args.taxa)  # Yes, I know. No sanitation of user input at all.
except Exception as e:
    logger.error("Error in regular expression.")
    logger.error(e)
    exit(1)

# Load the primary phylogenetic tree; abort with a logged error if it cannot
# be read.
logger.info("Loading and pruning the primary tree.")
try:
    main_tree = Tree(args.tree)
except Exception as e:
    logger.error("Error loading the phylogenetic tree from file {}.".format(
        args.tree))
    logger.error(e)
    exit(1)
main_tree_leaves_all = set(main_tree.get_leaf_names())

# Keep only the leaves whose names match the user-supplied pattern.
selected_main_leave_names = list(filter(leaf_regex.match, main_tree_leaves_all))
logger.info("The following leaf nodes matching your query were found:")
for leaf in selected_main_leave_names:
    logger.info(leaf)

# Prune to the selection; branch lengths are kept so distances stay valid.
logger.info("Pruning the tree.")
main_tree.prune(selected_main_leave_names, preserve_branch_length=True)
logger.info(main_tree.write())
示例#33
0
# (tip_a, tip_b) -> distance; filled in below.
matrix = {}
def get_dist(a, b):
    """Return the stored distance between tips ``a`` and ``b``.

    Identical tips are 0.0.  The pair is looked up as ``(a, b)`` and then as
    ``(b, a)``; ``KeyError`` is raised if neither order is stored.
    """
    if a == b:
        return 0.0
    key = (a, b)
    if key not in matrix:
        key = (b, a)
    return matrix[key]

# Pairwise tip distances.  Presumably each lineage is the set of nodes from a
# tip up to the root (confirm against where `lineages` is built); if so, the
# nodes in exactly one of the two sets carry the branches on the tip-to-tip
# path, and their summed `dist` is the patristic distance.
for tip_a, tip_b in itertools.permutations(lineages, 2):
    d = sum(n.dist for n in lineages[tip_a] ^ lineages[tip_b])
    matrix[(tip_a, tip_b)] = d

# Tab-separated distance matrix.  print() with a single argument behaves the
# same under Python 2 and Python 3 (the old print statements are py2-only).
leaves = t.get_leaf_names()
print('\t'.join(['#names'] + leaves))
for tip_a in leaves:
    row = [tip_a]
    for tip_b in leaves:
        row.append(get_dist(tip_a, tip_b))
    print('\t'.join(map(str, row)))


# test

import random
s = random.sample(matrix.keys(), 1000)
for a,b in s:
    d0 = get_dist(a, b)
    d1 = t.get_distance(a, b)