def main(argv=None):
    """script main.

    parses command line options in sys.argv, unless *argv* is given.

    Reads trees from stdin (via ``TreeTools.Newick2Nexus``) and combines
    them according to ``--method``:

    non-redundant
        group trees by compatibility (``TreeTools.IsCompatible``) and
        output the first tree of each group.
    select-largest
        group trees by name (or by the first group matched by
        ``--regex-id``) and output the tree picked by ``getBestTree``
        for each group.
    consensus
        build a consensus tree with phylip ``consense``.
    min, max, sum, mean, counts
        aggregate branch lengths node by node onto the first tree.
    median, stddev
        compute the statistic over all branch lengths of a node.

    Branch lengths listed in ``filtered_branch_lengths`` are treated as
    missing by the aggregation methods. For mean, median and stddev, nodes
    without any valid value are set to ``--error-branchlength`` if given;
    otherwise a ValueError is raised.
    """

    if argv is None:
        argv = sys.argv

    parser = E.OptionParser(
        version="%prog version: $Id: trees2tree.py 2782 2009-09-10 11:40:29Z andreas $",
        usage=globals()["__doc__"])

    parser.add_option("-m", "--method", dest="method", type="choice",
                      choices=("counts", "min", "max", "sum", "mean", "median",
                               "stddev", "non-redundant", "consensus",
                               "select-largest"),
                      help="aggregation function.")

    parser.add_option("-r", "--regex-id", dest="regex_id", type="string",
                      help="regex pattern to extract identifier from tree name for the selection functions.")

    parser.add_option("-w", "--write-values", dest="write_values", type="string",
                      help="if processing multiple trees, write values to file.")

    parser.add_option("-e", "--error-branchlength", dest="error_branchlength", type="float",
                      help="set branch length without counts to this value.")

    parser.set_defaults(
        method="mean",
        regex_id=None,
        filtered_branch_lengths=(-999.0, 999.0),
        write_values=None,
        error_branchlength=None,
        separator=":",
    )

    # NOTE(review): argv is not passed on to E.Start - confirm whether
    # E.Start accepts an argv argument so that main(argv) is honoured.
    (options, args) = E.Start(parser, add_pipe_options=True)

    if options.loglevel >= 2:
        options.stdlog.write("# reading trees from stdin.\n")
        options.stdlog.flush()

    nexus = TreeTools.Newick2Nexus(sys.stdin)
    if options.loglevel >= 1:
        options.stdlog.write(
            "# read %i trees from stdin.\n" % len(nexus.trees))

    nskipped = 0
    ninput = len(nexus.trees)
    noutput = 0
    nerrors = 0

    if options.method == "non-redundant":
        # compute non-redundant trees: each tree is counted towards the
        # first compatible template, or becomes a new template itself.
        template_trees = []
        template_counts = []
        for ntree, tree in enumerate(nexus.trees):
            for x, template in enumerate(template_trees):
                is_compatible, _ = TreeTools.IsCompatible(tree, template)
                if is_compatible:
                    template_counts[x] += 1
                    break
            else:
                template_counts.append(1)
                template_trees.append(tree)

            if options.loglevel >= 2:
                options.stdlog.write(
                    "# tree=%i, ntemplates=%i\n" % (ntree, len(template_trees)))

        for x, template in enumerate(template_trees):
            if options.loglevel >= 1:
                # percentages are relative to the number of input trees
                options.stdlog.write(
                    "# tree: %i, counts: %i, percent=%5.2f\n" %
                    (x, template_counts[x], template_counts[x] * 100.0 / ninput))
            options.stdout.write(TreeTools.Tree2Newick(template) + "\n")
            noutput += 1

    elif options.method in ("select-largest",):
        # select one of the trees with the same name.
        clusters = {}
        for x, tree in enumerate(nexus.trees):
            name = tree.name
            if options.regex_id:
                match = re.search(options.regex_id, name)
                if match is None:
                    raise ValueError(
                        "regex '%s' does not match tree name '%s'" %
                        (options.regex_id, name))
                name = match.groups()[0]
            clusters.setdefault(name, []).append(x)

        new_trees = [getBestTree([nexus.trees[x] for x in cluster], options.method)
                     for cluster in clusters.values()]

        for new_tree in new_trees:
            options.stdout.write(">%s\n" % new_tree.name)
            options.stdout.write(TreeTools.Tree2Newick(new_tree) + "\n")
            noutput += 1

        nskipped = ninput - noutput

    elif options.method == "consensus":
        phylip = WrapperPhylip.Phylip()
        phylip.setLogLevel(options.loglevel - 2)
        phylip.setProgram("consense")
        phylip.setOptions(["Y"])
        phylip.setTrees(nexus.trees)

        result = phylip.run()

        options.stdout.write(
            "# consensus tree built from %i trees\n" % (phylip.mNInputTrees))
        options.stdout.write(
            TreeTools.Tree2Newick(result.mNexus.trees[0]) + "\n")
        noutput = 1

    else:
        filtered = options.filtered_branch_lengths

        if options.method in ("min", "max", "sum", "mean", "counts"):

            xtree = nexus.trees[0]

            # ntotals[node] counts the valid (unfiltered) branch lengths
            # aggregated at that node. Filtered values of the first tree
            # are replaced by a placeholder 0 and are not counted.
            ntotals = {}
            for n in xtree.chain.keys():
                data = xtree.node(n).data
                if data.branchlength in filtered:
                    data.branchlength = 0
                    ntotals[n] = 0
                else:
                    ntotals[n] = 1

            aggregators = {
                "min": min,
                "max": max,
                "sum": lambda x, y: x + y,
                "mean": lambda x, y: x + y,
                "counts": lambda x, y: x + 1,
            }
            f = aggregators[options.method]

            if options.method == "counts":
                # start counting with the first tree: 1 if its value is valid
                for n in xtree.chain.keys():
                    xtree.node(n).data.branchlength = ntotals[n]

            for tree in nexus.trees[1:]:
                for n in tree.chain.keys():
                    value = tree.node(n).data.branchlength
                    if value in filtered:
                        continue
                    xdata = xtree.node(n).data
                    if ntotals[n] == 0 and options.method in ("min", "max"):
                        # no valid value seen yet - do not compare against
                        # the placeholder 0
                        xdata.branchlength = value
                    else:
                        xdata.branchlength = f(xdata.branchlength, value)
                    ntotals[n] += 1

            if options.method == "mean":
                for n in xtree.chain.keys():
                    data = xtree.node(n).data
                    if ntotals[n] > 0:
                        data.branchlength = float(
                            data.branchlength) / ntotals[n]
                    elif options.error_branchlength is not None:
                        data.branchlength = options.error_branchlength
                        nerrors += 1
                        if options.loglevel >= 1:
                            options.stdlog.write(
                                "# no counts for node %i - set to %f\n" %
                                (n, options.error_branchlength))
                    else:
                        raise ValueError("no counts for node %i" % n)

        else:
            # collect all values for trees
            values = [[] for x in range(TreeTools.GetSize(nexus.trees[0]))]

            for tree in nexus.trees:
                for n, node in tree.chain.items():
                    if node.data.branchlength not in filtered:
                        values[n].append(node.data.branchlength)

            tree = nexus.trees[0]
            for n, node in tree.chain.items():
                if values[n]:
                    if options.method == "stddev":
                        node.data.branchlength = scipy.std(values[n])
                    elif options.method == "median":
                        node.data.branchlength = scipy.median(values[n])
                elif options.error_branchlength is not None:
                    node.data.branchlength = options.error_branchlength
                    nerrors += 1
                    if options.loglevel >= 1:
                        options.stdlog.write(
                            "# no counts for node %i - set to %f\n" %
                            (n, options.error_branchlength))
                else:
                    raise ValueError("no counts for node %i" % n)

            if options.write_values:
                with open(options.write_values, "w") as outfile:
                    for n, node in tree.chain.items():
                        values[n].sort()
                        identifier = options.separator.join(
                            sorted(TreeTools.GetLeaves(tree, n)))
                        outfile.write("%s\t%s\n" %
                                      (identifier, ";".join(map(str, values[n]))))

        del nexus.trees[1:]
        options.stdout.write(TreeTools.Nexus2Newick(nexus) + "\n")
        noutput = 1

    if options.loglevel >= 1:
        options.stdlog.write("# ntotal=%i, nskipped=%i, noutput=%i, nerrors=%i\n" % (
            ninput, nskipped, noutput, nerrors))

    E.Stop()
def getOrthologNodes(tree, positive_set, options, selector="strict", outgroups=None):
    """get all ortholog nodes in tree for species in positive_set.

    The tree is walked bottom-up. For every node, the genes below it are
    collected per species; only species in *positive_set* are tracked.
    Depending on *selector*, a node is recorded as an ortholog node when:

    strict
        there is exactly one gene per species in *positive_set*. A node
        with more than one gene for a species is terminated.
    degenerate
        there is at least one gene per species in *positive_set* and
        more genes than species in total.
    lineage
        there are at least as many genes as species in *positive_set*.
        Leaves of species outside *positive_set* terminate their
        ancestors. Recorded nodes are not terminated, so their ancestors
        can be recorded as well.
    any
        all species in *positive_set* are present.
    outgroup
        the node has genes from *outgroups* and at least one gene of
        another species. Requires *outgroups*.

    Avoid double counting: once a node has been recorded (except with the
    lineage selector), it and all nodes further up the tree are
    terminated and ignored.

    Species are identified by their column in ``options.org2column``;
    *positive_set* and *outgroups* contain such columns. Leaf taxa are
    decoded with ``parseIdentifier``.

    Returns a list of ``(node_id, genes)`` tuples, where ``genes[column]``
    is the set of genes of that species below ``node_id``.

    Raises ValueError for an unknown *selector*, or if *outgroups* is
    missing for the outgroup selector.
    """
    nspecies = len(options.org2column)

    # defaults, overridden below as required by the selector
    exit_function = lambda num_genes_for_species: False
    check_outgroup_function = lambda genes_at_node: False
    negative_set = set()

    if selector == "strict":
        # strict orthologs: at most one gene per species
        exit_function = lambda num_genes_for_species: num_genes_for_species > 1
        keep_function = lambda num_genes_for_species: num_genes_for_species == 1
        total_genes_function = lambda num_genes_at_node, num_species_in_pattern: \
            num_genes_at_node == num_species_in_pattern
        total_species_function = lambda num_species_at_node, num_species_in_pattern: \
            num_species_at_node == num_species_in_pattern
    elif selector == "degenerate":
        # degenerate orthologs: any number of genes per species
        keep_function = lambda num_genes_for_species: num_genes_for_species > 0
        total_genes_function = lambda num_genes_at_node, num_species_in_pattern: \
            num_genes_at_node > num_species_in_pattern
        total_species_function = lambda num_species_at_node, num_species_in_pattern: \
            num_species_at_node == num_species_in_pattern
    elif selector == "lineage":
        # lineage specific duplications
        keep_function = lambda num_genes_for_species: num_genes_for_species > 1
        total_genes_function = lambda num_genes_at_node, num_species_in_pattern: \
            num_genes_at_node >= num_species_in_pattern
        total_species_function = lambda num_species_at_node, num_species_in_pattern: False
        negative_set = set(range(nspecies)).difference(positive_set)
    elif selector == "any":
        # any number of orthologs, including orphans
        keep_function = lambda num_genes_for_species: True
        total_genes_function = lambda num_genes_at_node, num_species_in_pattern: True
        total_species_function = lambda num_species_at_node, num_species_in_pattern: \
            num_species_at_node == num_species_in_pattern
    elif selector == "outgroup":
        if not outgroups:
            raise ValueError(
                "usage error: please supply outgroups if 'outgroup'-selector is chosen.")
        keep_function = lambda num_genes_for_species: num_genes_for_species > 0
        total_genes_function = lambda num_genes_at_node, num_species_in_pattern: False
        total_species_function = lambda num_species_at_node, num_species_in_pattern: False
        # the node needs genes from the outgroups, but fewer than the total
        # number of genes at the node, i.e. at least one gene of another species.
        check_outgroup_function = lambda genes_at_node: \
            0 < sum(len(genes_at_node[x]) for x in outgroups) < \
            sum(len(x) for x in genes_at_node)
    else:
        raise ValueError("unknown selector %s" % selector)

    if options.loglevel >= 5:
        options.stdlog.write("# gene tree\n")
        tree.display()

    nnodes = TreeTools.GetSize(tree) + 1
    # genes[node_id][column]: set of genes of that species below node_id;
    # None once the node has been terminated.
    genes = [[set() for y in range(nspecies)] for x in range(nnodes)]

    ortholog_nodes = []
    num_species_in_pattern = len(positive_set)

    def count_genes(node_id):
        """record the genes per species for *node_id*."""
        node = tree.node(node_id)

        if options.loglevel >= 6:
            options.stdlog.write("# node_id=%i\n" % node_id)
        if options.loglevel >= 10:
            options.stdlog.write("# sets=%s\n" % (str(genes)))

        if not node.succ:
            # process leaf
            s, t, g, q = parseIdentifier(node.data.taxon, options)
            c = options.org2column[s]
            if c in positive_set:
                genes[node_id][c].add(g)
            elif c in negative_set:
                genes[node_id] = None
            return

        # process non-leaf node
        for s in node.succ:
            # propagate: terminated nodes force upper nodes to terminate
            # (assigned to None).
            if not genes[s]:
                genes[node_id] = None
                return

            # the counts are recomputed from the accumulated sets, so after
            # the last child they cover all children of the node.
            num_genes_at_node = 0
            num_species_at_node = 0

            for x in positive_set:
                genes[node_id][x] = genes[node_id][x].union(genes[s][x])
                num_genes_for_species = len(genes[node_id][x])
                if exit_function(num_genes_for_species):
                    genes[node_id] = None
                    return
                num_genes_at_node += num_genes_for_species
                if num_genes_for_species:
                    num_species_at_node += 1

        if options.loglevel >= 6:
            options.stdlog.write(
                "# node=%i species_at_node=%i genes_at_node=%i ngenes=%i\n" %
                (node_id, num_species_at_node, num_genes_at_node,
                 sum(len(x) for x in genes[node_id])))
            options.stdlog.write("# genes at node %i\t%s\n" %
                                 (node_id, genes[node_id]))
            if outgroups:
                options.stdlog.write(
                    "# outgroup genes=%i outgroup check=%s\n" %
                    (sum(len(genes[node_id][x]) for x in outgroups),
                     check_outgroup_function(genes[node_id])))

        # check stop criterion
        if total_species_function(num_species_at_node, num_species_in_pattern):
            # check if positive requirements are fulfilled for all species
            for x in positive_set:
                if not keep_function(len(genes[node_id][x])):
                    if options.loglevel >= 6:
                        options.stdlog.write(
                            "# keep function false for species %i\n" % x)
                    break
            else:
                if total_genes_function(num_genes_at_node, num_species_in_pattern):
                    if options.loglevel >= 6:
                        options.stdlog.write(
                            "# recording node %i\n" % node_id)
                    ortholog_nodes.append((node_id, genes[node_id]))
                    genes[node_id] = None
                    return
        elif check_outgroup_function(genes[node_id]):
            ortholog_nodes.append((node_id, genes[node_id]))
            genes[node_id] = None
            return
        elif negative_set:
            if total_genes_function(num_genes_at_node, num_species_in_pattern):
                if options.loglevel >= 6:
                    options.stdlog.write("# recording node %i\n" % node_id)
                ortholog_nodes.append((node_id, genes[node_id]))

    tree.dfs(tree.root, post_function=count_genes)

    return ortholog_nodes
def getMergers(tree, map_strain2species, options):
    """merge strains to species.

    Walks *tree* bottom-up and collects, for every node, the genes of each
    strain in *map_strain2species* found below it. A node is terminated
    (and with it all its ancestors) as soon as a strain contributes more
    than one gene below it, or a leaf of an unmapped strain lies below it.
    At a node where two strains of the same species each contribute
    exactly one gene, the first such pair is recorded as a merge. The node
    is then terminated so that the genes are not merged again higher up.

    Currently, only binary merges are supported.

    Returns a list of tuples ``(node_id, species, strain_x, gene_x,
    strain_y, gene_y)``, one per pair of genes to be merged.
    """
    nnodes = TreeTools.GetSize(tree) + 1
    all_strains = list(map_strain2species.keys())

    # genes[node_id][strain]: set of genes of *strain* below node_id;
    # None once the node has been terminated.
    genes = [dict((strain, set()) for strain in all_strains)
             for x in range(nnodes)]

    # build list of strain pairs (of the same species) that can be joined.
    map_species2strain = IOTools.getInvertedDictionary(map_strain2species)
    pairs = []
    for species, strains in map_species2strain.items():
        for x in range(len(strains)):
            for y in range(x):
                pairs.append((strains[x], strains[y]))

    # each entry is a pair of genes of the same species but different
    # strains to be joined.
    map_genes2new_genes = []

    # (strain, gene) -> node_id at which the gene was merged. Tuples are
    # used as keys so that different strain/gene splits cannot collide.
    # This is to ensure that no gene is merged twice.
    merged_genes = {}

    def count_genes(node_id):
        """record the genes per strain below *node_id*."""
        node = tree.node(node_id)

        if not node.succ:
            # process leaf
            strain, t, g, q = parseIdentifier(node.data.taxon, options)
            if strain in map_strain2species:
                genes[node_id][strain].add(g)
            else:
                # do not process nodes that do not need to be mapped
                genes[node_id] = None
            return

        this_node_set = genes[node_id]
        for s in node.succ:
            if not genes[s]:
                # propagate: terminated nodes force upper nodes to terminate
                this_node_set = None
                break
            # merge genes from the child. Leaves of unmapped strains have
            # already been terminated above.
            for strain in all_strains:
                this_node_set[strain] = this_node_set[strain].union(
                    genes[s][strain])
                if len(this_node_set[strain]) > 1:
                    # more than one gene for a single strain, so no join
                    this_node_set = None
                    break
            if this_node_set is None:
                # also stop the loop over children; the next child would
                # otherwise index into None.
                break

        if this_node_set is None:
            genes[node_id] = None
            return

        for strain_x, strain_y in pairs:
            if len(this_node_set[strain_x]) != 1 or \
                    len(this_node_set[strain_y]) != 1:
                continue

            species = map_strain2species[strain_x]
            (gene_x,) = this_node_set[strain_x]
            (gene_y,) = this_node_set[strain_y]

            # check if these two genes have already been merged or are
            # merged with other partners already. Genes merged together
            # are assigned the same node_id.
            key_x = (strain_x, gene_x)
            key_y = (strain_y, gene_y)
            if key_x in merged_genes and key_y in merged_genes:
                merge = merged_genes[key_x] == merged_genes[key_y]
            elif key_x not in merged_genes and key_y not in merged_genes:
                merge = True
                merged_genes[key_x] = node_id
                merged_genes[key_y] = node_id
            else:
                merge = False

            if merge:
                map_genes2new_genes.append(
                    (node_id, species, strain_x, gene_x, strain_y, gene_y))

            # once two genes have been joined, they can not be remapped
            # further
            genes[node_id] = None
            return

    tree.dfs(tree.root, post_function=count_genes)

    return map_genes2new_genes
def extractSubtrees(tree, extract_species, options):
    """extract subtrees from tree.

    Splits a rooted tree at outgroups or at out-paralogs.

    A leaf is an outgroup if ``extract_species(taxon)`` is in
    ``options.outgroup_species``, and an ingroup otherwise. Walking the
    tree bottom-up, two sibling subtrees are merged if

    * both contain outgroups but no ingroups,
    * either contains ingroups and neither contains outgroups, or
    * exactly one contains outgroups and their species do not overlap.

    Otherwise, every child subtree that contains both ingroups and
    outgroups is written out as a cluster, and the remaining child (if
    any) is carried upwards. If clusters have been written, what is left
    at the root is added to the last cluster when it contains only
    ingroups or only outgroups. In all other cases it forms a cluster of
    its own.

    The tree is expected to be binary: ``a, b = node.succ`` fails for
    internal nodes with a different number of children.

    Returns a list of clusters, each a set of the leaf taxa belonging to
    one subtree.
    """
    nnodes = TreeTools.GetSize(tree)

    # per node: number of ingroup / outgroup leaves, and the species and
    # leaf taxa that have not yet been written out as a cluster.
    nin = [0] * nnodes
    nout = [0] * nnodes
    taxa = [set() for x in range(nnodes)]
    otus = [set() for x in range(nnodes)]
    clusters = []

    def assign(node_id, source=None):
        """copy the state of node *source* to *node_id*, or reset it if None."""
        if source is None:
            nout[node_id], nin[node_id] = 0, 0
            taxa[node_id], otus[node_id] = set(), set()
        else:
            nout[node_id], nin[node_id] = nout[source], nin[source]
            taxa[node_id], otus[node_id] = taxa[source], otus[source]

    def update_groups(node_id):
        node = tree.node(node_id)

        if not node.succ:
            species = extract_species(node.data.taxon)
            taxa[node_id] = set((species,))
            otus[node_id] = set((node.data.taxon,))
            if species in options.outgroup_species:
                nout[node_id] = 1
            else:
                nin[node_id] = 1
            return

        a, b = node.succ
        oa, ob = nout[a] > 0, nout[b] > 0
        ia, ib = nin[a] > 0, nin[b] > 0
        overlap = len(taxa[a].intersection(taxa[b])) > 0

        merge = ((oa and ob and not (ia or ib)) or
                 ((ia or ib) and not (oa or ob)) or
                 (oa != ob and not overlap))

        if merge:
            nout[node_id] = nout[a] + nout[b]
            nin[node_id] = nin[a] + nin[b]
            taxa[node_id] = taxa[a].union(taxa[b])
            otus[node_id] = otus[a].union(otus[b])
        elif ia and oa and ib and ob:
            # write two complete subtrees
            assign(node_id)
            clusters.append(otus[a])
            clusters.append(otus[b])
        elif ia and oa:
            # write a, keep b
            assign(node_id, b)
            clusters.append(otus[a])
        elif ib and ob:
            # write b, keep a
            assign(node_id, a)
            clusters.append(otus[b])
        else:
            # two empty subtrees merge
            assign(node_id)

    TreeTools.TreeDFS(tree, tree.root, post_function=update_groups)

    # special treatment of root
    if otus[tree.root]:
        if clusters:
            # add to previous cluster if only outgroups or ingroups
            # in root node
            oa = nout[tree.root] > 0
            ia = nin[tree.root] > 0
            if oa != ia:
                clusters[-1] = clusters[-1].union(otus[tree.root])
            else:
                clusters.append(otus[tree.root])
        else:
            clusters.append(otus[tree.root])

    return clusters
def getSpeciesTreeMergers(tree, full_map_strain2species, options):
    """merge strains to species.

    Renames, in place, every leaf of *tree* whose taxon is a key of
    *full_map_strain2species* to the corresponding species.

    Then, walking the tree top-down, reports every node whose leaves all
    belong to a single renamed species (and that has more than one such
    leaf). Nodes below a reported node are not reported again.

    Returns a list of ``(node_id, leaves)`` tuples, where ``leaves`` is
    ``tree.get_leaves(node_id)``. Returns an empty list if no leaf was
    renamed. *options* is unused.
    """
    map_strain2species = {}

    for n in tree.get_terminals():
        node = tree.node(n)
        taxon = node.data.taxon
        if taxon in full_map_strain2species:
            map_strain2species[taxon] = full_map_strain2species[taxon]
            node.data.taxon = map_strain2species[taxon]

    if not map_strain2species:
        return []

    mapped_species = set(map_strain2species.values())

    # species_at_node[node_id]: species -> number of leaves below node_id.
    # Species are counted once per leaf; absent species count as 0.
    species_at_node = {}

    def count_species(node_id):
        """record the number of leaves per species for each node."""
        node = tree.node(node_id)
        counts = {}
        if node.succ:
            for child in node.succ:
                for species, count in species_at_node[child].items():
                    counts[species] = counts.get(species, 0) + count
        else:
            counts[node.data.taxon] = 1
        species_at_node[node_id] = counts

    tree.dfs(tree.root, post_function=count_species)

    # now merge all those that contain only a single species,
    # proceeding top-down.
    nodes_to_skip = set()
    mergers = []

    def skip_subtree(node_id):
        """mark all nodes below *node_id* so that they are not merged again."""
        stack = list(tree.node(node_id).succ)
        while stack:
            n = stack.pop()
            nodes_to_skip.add(n)
            stack.extend(tree.node(n).succ)

    def merge_species(node_id):
        if node_id in nodes_to_skip:
            return
        counts = species_at_node[node_id]
        total = sum(counts.values())
        for species in mapped_species:
            count = counts.get(species, 0)
            if count <= 1 or count != total:
                continue
            # merge species: all leaves below node_id are this species
            mergers.append((node_id, tree.get_leaves(node_id)))
            skip_subtree(node_id)
            break

    tree.dfs(tree.root, pre_function=merge_species)

    return mergers