Example #1
0
def main(argv=None):
    """script main.

    parses command line options in sys.argv, unless *argv* is given.

    Reads trees in newick format from stdin and combines them according
    to ``--method``:

    ``non-redundant``
        output one representative per group of trees that
        ``TreeTools.IsCompatible`` considers compatible.
    ``select-largest``
        group trees by name (or by the first group matched by
        ``--regex-id``) and output the tree chosen by ``getBestTree``
        for each group.
    ``consensus``
        build a consensus tree with the phylip program ``consense``.
    ``min``, ``max``, ``sum``, ``mean``, ``counts``
        aggregate branch lengths node by node onto the first tree.
    ``median``, ``stddev``
        compute per-node statistics over all trees.

    Branch lengths contained in ``filtered_branch_lengths`` are ignored
    when aggregating. Node ids of every tree are used to index the first
    tree, i.e. the trees are expected to share their node numbering.

    Raises ValueError if a node receives no values and
    ``--error-branchlength`` is not set, or if ``--regex-id`` does not
    match a tree name.
    """

    if argv is None:
        argv = sys.argv

    parser = E.OptionParser(version="%prog version: $Id: trees2tree.py 2782 2009-09-10 11:40:29Z andreas $",
                            usage=globals()["__doc__"])

    parser.add_option("-m", "--method", dest="method", type="choice",
                      choices=("counts", "min", "max", "sum", "mean", "median", "stddev", "non-redundant", "consensus",
                               "select-largest"),
                      help="aggregation function.")

    parser.add_option("-r", "--regex-id", dest="regex_id", type="string",
                      help="regex pattern to extract identifier from tree name for the selection functions.")

    parser.add_option("-w", "--write-values", dest="write_values", type="string",
                      help="if processing multiple trees, write values to file.")

    parser.add_option("-e", "--error-branchlength", dest="error_branchlength", type="float",
                      help="set branch length without counts to this value.")

    parser.set_defaults(
        method="mean",
        regex_id=None,
        filtered_branch_lengths=(-999.0, 999.0),
        write_values=None,
        error_branchlength=None,
        separator=":",
    )

    # NOTE(review): argv is not forwarded to E.Start, so options are
    # presumably always parsed from sys.argv - confirm against E.Start.
    (options, args) = E.Start(parser, add_pipe_options=True)

    if options.loglevel >= 2:
        options.stdlog.write("# reading trees from stdin.\n")
        options.stdlog.flush()

    nexus = TreeTools.Newick2Nexus(sys.stdin)
    if options.loglevel >= 1:
        options.stdlog.write(
            "# read %i trees from stdin.\n" % len(nexus.trees))

    nskipped = 0
    ninput = len(nexus.trees)
    noutput = 0
    nerrors = 0

    if options.method == "non-redundant":
        # compute non-redundant trees: each tree is counted towards the
        # first compatible template or becomes a new template itself.
        template_trees = []
        template_counts = []
        for ntree, tree in enumerate(nexus.trees):

            for x, template in enumerate(template_trees):
                is_compatible, reason = TreeTools.IsCompatible(
                    tree, template)
                if is_compatible:
                    template_counts[x] += 1
                    break
            else:
                template_counts.append(1)
                template_trees.append(tree)

            if options.loglevel >= 2:
                options.stdlog.write(
                    "# tree=%i, ntemplates=%i\n" % (ntree, len(template_trees)))

        for x, template in enumerate(template_trees):
            if options.loglevel >= 1:
                # percentages are relative to the number of input trees
                options.stdlog.write("# tree: %i, counts: %i, percent=%5.2f\n" %
                                     (x, template_counts[x], template_counts[x] * 100.0 / ninput))
            options.stdout.write(
                TreeTools.Tree2Newick(template) + "\n")
            noutput += 1

    elif options.method == "select-largest":
        # select one of the trees with the same name.
        clusters = {}
        for x, tree in enumerate(nexus.trees):
            n = tree.name

            if options.regex_id:
                match = re.search(options.regex_id, n)
                if match is None:
                    raise ValueError(
                        "could not extract identifier from tree name '%s' with regex '%s'" %
                        (n, options.regex_id))
                n = match.groups()[0]

            clusters.setdefault(n, []).append(x)

        new_trees = []
        for name, cluster in clusters.items():
            new_trees.append(
                getBestTree([nexus.trees[x] for x in cluster], options.method))

        for new_tree in new_trees:
            options.stdout.write(">%s\n" % new_tree.name)
            options.stdout.write(TreeTools.Tree2Newick(new_tree,) + "\n")
            noutput += 1

        nskipped = ninput - noutput

    elif options.method == "consensus":

        phylip = WrapperPhylip.Phylip()
        phylip.setLogLevel(options.loglevel - 2)
        phylip.setProgram("consense")
        phylip_options = []
        phylip_options.append("Y")

        phylip.setOptions(phylip_options)
        phylip.setTrees(nexus.trees)

        result = phylip.run()

        options.stdout.write(
            "# consensus tree built from %i trees\n" % (phylip.mNInputTrees))
        options.stdout.write(
            TreeTools.Tree2Newick(result.mNexus.trees[0]) + "\n")
        noutput = 1

    else:
        filtered = options.filtered_branch_lengths

        if options.method in ("min", "max", "sum", "mean", "counts"):

            xtree = nexus.trees[0]

            # number of values aggregated per node id. A filtered branch
            # length in the first tree does not count and is replaced by
            # the placeholder 0.
            ntotals = {}
            for n in xtree.chain.keys():
                data = xtree.node(n).data
                if data.branchlength in filtered:
                    data.branchlength = 0
                    ntotals[n] = 0
                else:
                    ntotals[n] = 1

            if options.method == "min":
                f = min
            elif options.method == "max":
                f = max
            elif options.method in ("sum", "mean"):
                f = lambda x, y: x + y
            elif options.method == "counts":
                f = lambda x, y: x + 1
                # start counting with the first tree: 1 if its branch
                # length is used, 0 if it was filtered.
                for n in xtree.chain.keys():
                    xtree.node(n).data.branchlength = ntotals[n]
            else:
                raise ValueError("unknown option %s" % options.method)

            for tree in nexus.trees[1:]:
                for n in tree.chain.keys():
                    value = tree.node(n).data.branchlength
                    if value in filtered:
                        continue
                    data = xtree.node(n).data
                    if ntotals[n] == 0 and options.method in ("min", "max"):
                        # nothing recorded yet: the placeholder 0 must not
                        # take part in the comparison.
                        data.branchlength = value
                    else:
                        data.branchlength = f(data.branchlength, value)
                    ntotals[n] += 1

            if options.method == "mean":
                for n in xtree.chain.keys():
                    data = xtree.node(n).data
                    if ntotals[n] > 0:
                        data.branchlength = float(
                            data.branchlength) / ntotals[n]
                    elif options.error_branchlength is not None:
                        data.branchlength = options.error_branchlength
                        nerrors += 1
                        if options.loglevel >= 1:
                            options.stdlog.write(
                                "# no counts for node %i - set to %f\n" % (n, options.error_branchlength))
                    else:
                        raise ValueError("no counts for node %i" % n)

        else:
            # collect all unfiltered values per node id over all trees
            values = {}
            for tree in nexus.trees:
                for n, node in tree.chain.items():
                    if node.data.branchlength not in filtered:
                        values.setdefault(n, []).append(node.data.branchlength)

            tree = nexus.trees[0]
            for n, node in tree.chain.items():
                node_values = values.get(n)
                if node_values:
                    if options.method == "stddev":
                        node.data.branchlength = scipy.std(node_values)
                    elif options.method == "median":
                        node.data.branchlength = scipy.median(node_values)
                elif options.error_branchlength is not None:
                    node.data.branchlength = options.error_branchlength
                    nerrors += 1
                    if options.loglevel >= 1:
                        options.stdlog.write(
                            "# no counts for node %i - set to %f\n" % (n, options.error_branchlength))
                else:
                    raise ValueError("no counts for node %i" % n)

            if options.write_values:
                # one line per node: sorted leaf names and sorted values
                with open(options.write_values, "w") as outfile:
                    for n, node in tree.chain.items():
                        node_id = options.separator.join(
                            sorted(TreeTools.GetLeaves(tree, n)))
                        outfile.write("%s\t%s\n" %
                                      (node_id, ";".join(map(str, sorted(values.get(n, []))))))

        del nexus.trees[1:]
        options.stdout.write(TreeTools.Nexus2Newick(nexus) + "\n")
        noutput = 1

    if options.loglevel >= 1:
        options.stdlog.write("# ntotal=%i, nskipped=%i, noutput=%i, nerrors=%i\n" % (
            ninput, nskipped, noutput, nerrors))

    E.Stop()
Example #2
0
def getOrthologNodes(tree,
                     positive_set,
                     options,
                     selector="strict",
                     outgroups=None):
    """get all ortholog nodes in tree for species in positive_set.

    Depending on the selector function, different sets are returned:

    If selector is "strict", only strict orthologs are returned. These
    contain exactly one gene per species for all species in the positive_set.

    If selector is "degenerate", only degenerate orthologs are returned.
    These contain at least one gene per species for species in the
    positive_set and more genes than species in total.

    If selector is "lineage", every node is recorded whose number of genes
    is at least the number of species in positive_set, as long as no gene
    of a species outside positive_set lies below it. This requires at
    least one species outside positive_set.

    If selector is "any", the first node containing genes of all species
    in positive_set is recorded, irrespective of the number of genes.

    If selector is "outgroup", the first node containing genes of the
    species (columns) in *outgroups* as well as genes of other species in
    positive_set is recorded.

    Genes are collected per species column (``options.org2column``) in a
    post-order traversal of the tree.

    Avoid double counting: if you are interested in species A and B,
    any branches involving other species should be ignored. Make sure to
    only count once and not every time a discarded branch is removed.

    Thus, as soon as A and B merge, any node further up the tree has to be
    ignored.

    total_genes_function:   if true, node is recorded
    total_species_function: if true, iteration stops

    Returns a list of ``(node_id, genes)`` tuples, where ``genes`` is a
    list indexed by species column holding the set of gene identifiers
    of that species below the node.

    Raises ValueError for an unknown *selector* or if *selector* is
    "outgroup" but no *outgroups* are given.
    """

    nspecies = len(options.org2column)

    if selector == "strict":
        # strict orthologs: at most one gene per species
        exit_function = lambda num_genes_for_species: num_genes_for_species > 1
        keep_function = lambda num_genes_for_species: num_genes_for_species == 1
        total_genes_function = lambda num_genes_at_node, num_species_in_pattern: num_genes_at_node == num_species_in_pattern
        total_species_function = lambda num_species_at_node, num_species_in_pattern: num_species_at_node == num_species_in_pattern
        check_outgroup_function = lambda x: False
        negative_set = set()
    elif selector == "degenerate":
        # degenerate orthologs: any number of genes per species,
        exit_function = lambda num_genes_for_species: False
        keep_function = lambda num_genes_for_species: num_genes_for_species > 0
        total_genes_function = lambda num_genes_at_node, num_species_in_pattern: num_genes_at_node > num_species_in_pattern
        total_species_function = lambda num_species_at_node, num_species_in_pattern: num_species_at_node == num_species_in_pattern
        check_outgroup_function = lambda x: False
        negative_set = set()
    elif selector == "lineage":
        # lineage specific duplications: at least 1 gene
        exit_function = lambda num_genes_for_species: False
        keep_function = lambda num_genes_for_species: num_genes_for_species > 1
        total_genes_function = lambda num_genes_at_node, num_species_in_pattern: num_genes_at_node >= num_species_in_pattern
        total_species_function = lambda num_species_at_node, num_species_in_pattern: False
        check_outgroup_function = lambda x: False
        negative_set = set(range(nspecies)).difference(positive_set)
    elif selector == "any":
        # any number of orthologs, including
        # orphans
        exit_function = lambda num_genes_for_species: False
        keep_function = lambda num_genes_for_species: True
        total_genes_function = lambda num_genes_at_node, num_species_in_pattern: True
        total_species_function = lambda num_species_at_node, num_species_in_pattern: num_species_at_node == num_species_in_pattern
        check_outgroup_function = lambda x: False
        negative_set = set()
    elif selector == "outgroup":
        # group selector
        exit_function = lambda num_genes_for_species: False
        keep_function = lambda num_genes_for_species: num_genes_for_species > 0
        total_genes_function = lambda num_genes_at_node, num_species_in_pattern: False
        total_species_function = lambda num_species_at_node, num_species_in_pattern: False
        # check for outgroup: needs to have outgroup and at least one other
        # species, i.e.: sum of all genes in outgroups is larger than 0 but
        # smaller than the sum of all genes in all species
        if not outgroups:
            raise ValueError(
                "usage error: please supply outgroups if 'outgroup'-selector is chosen.")
        check_outgroup_function = lambda genes: 0 < sum(
            [len(genes[x])
             for x in outgroups]) < sum(map(lambda x: len(x), genes))
        negative_set = set()
    else:
        raise ValueError("unknown selector %s" % selector)

    # work here: set genes[node_id] to None,
    # 1. if the gene count for a species of interest is > 1
    # 2. if the gene count for all species of interest is 1 in
    # the child node.

    if options.loglevel >= 5:
        options.stdlog.write("# gene tree\n")
        tree.display()

    n = TreeTools.GetSize(tree) + 1
    # per node: one set of gene ids per species column; replaced by None
    # once the node (and hence every ancestor) is excluded.
    genes = [[set() for x in range(nspecies)] for y in range(n)]

    ortholog_nodes = []

    # species in pattern
    num_species_in_pattern = len(positive_set)

    def count_genes(node_id):
        """record the genes per species for node *node_id* (post-order).
        """
        node = tree.node(node_id)

        if options.loglevel >= 6:
            options.stdlog.write("# node_id=%i\n" % node_id)
            if options.loglevel >= 10:
                options.stdlog.write("# sets=%s\n" % (str(genes)))

        if not node.succ:
            # process leaf
            s, t, g, q = parseIdentifier(node.data.taxon, options)
            c = options.org2column[s]
            if c in positive_set:
                genes[node_id][c].add(g)
            elif c in negative_set:
                genes[node_id] = None
            return

        # process non-leaf node: merge the gene sets of all children
        node_genes = genes[node_id]
        for s in node.succ:
            child_genes = genes[s]
            # propagate: terminated nodes force upper nodes to terminate
            # (assigned to None).
            if not child_genes:
                genes[node_id] = None
                return

            for x in positive_set:
                node_genes[x] = node_genes[x].union(child_genes[x])
                if exit_function(len(node_genes[x])):
                    genes[node_id] = None
                    return

        # totals over the complete (merged) gene sets at this node
        ngenes_per_species = [len(node_genes[x]) for x in positive_set]
        num_genes_at_node = sum(ngenes_per_species)
        num_species_at_node = len([x for x in ngenes_per_species if x])

        if options.loglevel >= 6:
            options.stdlog.write(
                "# node=%i species_at_node=%i genes_at_node=%i ngenes=%i\n" %
                (node_id, num_species_at_node, num_genes_at_node,
                 sum([len(x) for x in node_genes])))
            options.stdlog.write("# genes at node %i\t%s\n" %
                                 (node_id, node_genes))
            if outgroups:
                options.stdlog.write(
                    "# outgroup genes=%i, outgroup check=%s\n" %
                    (sum([len(node_genes[x]) for x in outgroups]),
                     check_outgroup_function(node_genes)))

        # check stop criterion
        if total_species_function(num_species_at_node,
                                  num_species_in_pattern):
            # check if positive requirements are fulfilled
            for x in positive_set:
                if not keep_function(len(node_genes[x])):
                    if options.loglevel >= 6:
                        options.stdlog.write(
                            "# keep function false for species %i\n" % x)
                    break
            else:
                if total_genes_function(num_genes_at_node,
                                        num_species_in_pattern):
                    if options.loglevel >= 6:
                        options.stdlog.write(
                            "# recording node %i\n" % node_id)
                    ortholog_nodes.append((node_id, node_genes))
            genes[node_id] = None
        elif check_outgroup_function(node_genes):
            ortholog_nodes.append((node_id, node_genes))
            genes[node_id] = None
        elif negative_set:
            if total_genes_function(num_genes_at_node,
                                    num_species_in_pattern):
                if options.loglevel >= 6:
                    options.stdlog.write("# recording node %i\n" % node_id)
                ortholog_nodes.append((node_id, node_genes))

    tree.dfs(tree.root, post_function=count_genes)

    return ortholog_nodes
Example #3
0
def getMergers(tree, map_strain2species, options):
    """merge strains to species.

    Traverses *tree* bottom-up and collects, for every node, the genes of
    each strain in *map_strain2species* below it. A node yields a merger
    when it joins exactly one gene from each of two strains that map to
    the same species. Each gene takes part in at most one merger.

    A node is excluded, together with all of its ancestors, if it
    contains more than one gene of any strain, if it contains a gene of a
    strain not in *map_strain2species*, or once a joinable pair has been
    found at it.

    Currently, only binary merges are supported.

    Returns a list of ``(node_id, species, strain_x, gene_x, strain_y,
    gene_y)`` tuples, one per merger.
    """

    n = TreeTools.GetSize(tree) + 1
    all_strains = list(map_strain2species.keys())

    # per node: strain -> set of gene ids below the node; replaced by None
    # once the node (and thereby all its ancestors) is excluded.
    genes = [dict((s, set()) for s in all_strains) for x in range(n)]

    # build list of strain pairs of the same species that can be joined.
    map_species2strain = IOTools.getInvertedDictionary(map_strain2species)
    pairs = []

    for species, strains in map_species2strain.items():
        for x in range(len(strains)):
            for y in range(0, x):
                pairs.append((strains[x], strains[y]))

    # each entry is a pair of genes of the same species
    # but different strains to be joined.
    map_genes2new_genes = []

    # (strain, gene) -> node_id at which the gene was merged, to ensure
    # that no gene is merged twice. Tuples are used as keys because the
    # concatenation of strain and gene names is ambiguous.
    merged_genes = {}

    def count_genes(node_id):
        """record the genes per strain for node *node_id* (post-order).
        """
        node = tree.node(node_id)

        if not node.succ:
            # process leaf
            strain, t, g, q = parseIdentifier(node.data.taxon, options)
            if strain in map_strain2species:
                genes[node_id][strain].add(g)
            else:
                # do not process nodes that do not need to be mapped
                genes[node_id] = None
            return

        # process non-leaf node: merge genes from all children
        this_node_set = genes[node_id]
        for s in node.succ:
            child_set = genes[s]
            # propagate: terminated nodes force upper nodes to terminate
            # (assigned to None).
            if not child_set:
                genes[node_id] = None
                return

            for strain in all_strains:
                this_node_set[strain] = this_node_set[strain].union(
                    child_set[strain])
                if len(this_node_set[strain]) > 1:
                    # more than one gene for a single strain, so no join.
                    # Stop immediately - the node is terminated and must
                    # not be merged with further children.
                    genes[node_id] = None
                    return

        for strain_x, strain_y in pairs:
            if len(this_node_set[strain_x]) != 1 or \
                    len(this_node_set[strain_y]) != 1:
                continue

            species = map_strain2species[strain_x]
            gene_x = list(this_node_set[strain_x])[0]
            gene_y = list(this_node_set[strain_y])[0]

            # check if these two genes have already been merged or are
            # merged with other partners already. Genes merged together
            # are assigned the same node_id.
            key_x = (strain_x, gene_x)
            key_y = (strain_y, gene_y)

            if key_x in merged_genes and key_y in merged_genes:
                merge = merged_genes[key_x] == merged_genes[key_y]
            elif key_x not in merged_genes and key_y not in merged_genes:
                merge = True
                merged_genes[key_x] = node_id
                merged_genes[key_y] = node_id
            else:
                # one of the genes is already merged with another partner
                merge = False

            if merge:
                map_genes2new_genes.append(
                    (node_id, species, strain_x, gene_x, strain_y, gene_y))

            # once two genes have been joined, they can not be remapped
            # further
            genes[node_id] = None
            return

    tree.dfs(tree.root, post_function=count_genes)

    return map_genes2new_genes
Example #4
0
def extractSubtrees(tree, extract_species, options):
    """extract subtrees from tree.

    Splits a rooted binary tree at outgroups or at out-paralogs.

    *extract_species* maps a taxon to its species. Taxa whose species is
    in ``options.outgroup_species`` count as outgroup, all others as
    ingroup.

    Returns a list of clusters (sets of taxa), one per extracted subtree.

    Raises ValueError (when unpacking the children) if an internal node
    does not have exactly two children, and RuntimeError if a node falls
    into none of the handled cases.
    """

    # NOTE(review): one slot more than GetSize(tree), in case node ids run
    # up to GetSize(tree) - confirm against TreeTools.GetSize.
    nnodes = TreeTools.GetSize(tree) + 1
    # per node, describing the subtree below it that has not yet been
    # written to a cluster:
    nin = [0] * nnodes                      # number of ingroup taxa
    nout = [0] * nnodes                     # number of outgroup taxa
    taxa = [set() for x in range(nnodes)]   # species
    otus = [set() for x in range(nnodes)]   # taxa

    clusters = []

    def update_groups(node_id):
        node = tree.node(node_id)

        if not node.succ:
            taxon = node.data.taxon
            species = extract_species(taxon)
            taxa[node_id] = set((species,))
            otus[node_id] = set((taxon,))
            if species in options.outgroup_species:
                nout[node_id] = 1
            else:
                nin[node_id] = 1
            return

        a, b = node.succ
        oa = nout[a] > 0
        ob = nout[b] > 0
        ia = nin[a] > 0
        ib = nin[b] > 0
        overlap = len(taxa[a].intersection(taxa[b])) > 0

        # merge, if
        # * both have outgroups but no ingroups
        # * either have ingroups but no outgroups.
        # * one has outgroups and ingroups don't overlap
        merge = ((oa and ob) and not (ia or ib)) or \
            ((ia or ib) and not (oa or ob)) or \
            (oa != ob and not overlap)

        if merge:
            nout[node_id] = nout[a] + nout[b]
            nin[node_id] = nin[a] + nin[b]
            taxa[node_id] = taxa[a].union(taxa[b])
            otus[node_id] = otus[a].union(otus[b])
        elif ia and oa and ib and ob:
            # write two complete subtrees
            nout[node_id] = 0
            nin[node_id] = 0
            taxa[node_id] = set()
            otus[node_id] = set()
            clusters.append(otus[a])
            clusters.append(otus[b])
        elif ia and oa:
            # write a, keep b
            nout[node_id] = nout[b]
            nin[node_id] = nin[b]
            taxa[node_id] = taxa[b]
            otus[node_id] = otus[b]
            clusters.append(otus[a])
        elif ib and ob:
            # write b, keep a
            nout[node_id] = nout[a]
            nin[node_id] = nin[a]
            taxa[node_id] = taxa[a]
            otus[node_id] = otus[a]
            clusters.append(otus[b])
        elif not (ia and ib and oa and ob):
            # two empty subtrees (everything below has already been
            # written to clusters) merge
            nout[node_id] = 0
            nin[node_id] = 0
            taxa[node_id] = set()
            otus[node_id] = set()
        else:
            tree.display()
            raise RuntimeError(
                "sanity check failed: unknown case for node %i "
                "(ia=%s, ib=%s, oa=%s, ob=%s)" % (node_id, ia, ib, oa, ob))

    TreeTools.TreeDFS(tree, tree.root,
                      post_function=update_groups)

    # special treatment of root: taxa not yet written to a cluster
    root = tree.root
    if otus[root]:
        oa = nout[root] > 0
        ia = nin[root] > 0
        if clusters and oa != ia:
            # add to previous cluster if only outgroups or only ingroups
            # remain in the root node
            clusters[-1] = clusters[-1].union(otus[root])
        else:
            clusters.append(otus[root])

    return clusters
Example #5
0
def getSpeciesTreeMergers(tree, full_map_strain2species, options):
    """merge strains to species.

    Renames, in place, every terminal taxon of *tree* that is a key of
    *full_map_strain2species* to its species. Then finds, top-down, the
    nodes whose terminals all belong to one renamed species (and that
    have more than one terminal). Nodes below such a node are not
    reported again.

    *options* is not used.

    Returns a list of ``(node_id, leaves)`` tuples, where ``leaves`` is
    ``tree.get_leaves(node_id)``, or an empty list if no taxon was
    renamed.
    """

    nnodes = TreeTools.GetSize(tree) + 1

    # simply rename all taxa of strains to the species.
    map_strain2species = {}
    for n in tree.get_terminals():
        node = tree.node(n)
        taxon = node.data.taxon
        if taxon in full_map_strain2species:
            species = full_map_strain2species[taxon]
            map_strain2species[taxon] = species
            node.data.taxon = species

    if not map_strain2species:
        return []

    # after renaming, several terminals may share a taxon name; use each
    # species only once so that counts are not multiplied when summing
    # over children.
    all_species = set(tree.get_taxa())
    mapped_species = set(map_strain2species.values())

    # per node: species -> number of terminals of that species below it
    species_at_node = [dict((s, 0) for s in all_species)
                       for x in range(nnodes)]

    def count_species(node_id):
        """record species for each node (post-order).
        """
        node = tree.node(node_id)
        if node.succ:
            # process non-leaf node
            counts = species_at_node[node_id]
            for s in node.succ:
                child_counts = species_at_node[s]
                for species in all_species:
                    counts[species] += child_counts[species]
        else:
            # process leaf
            species_at_node[node_id][node.data.taxon] = 1

    tree.dfs(tree.root, post_function=count_species)

    # now merge all those that contain only a single species
    # proceed top-down

    nodes_to_skip = set()
    mergers = []

    def merge_species(node_id):

        if node_id in nodes_to_skip:
            return

        counts = species_at_node[node_id]
        total = sum(counts.values())
        for species in mapped_species:
            if counts[species] <= 1 or counts[species] != total:
                continue

            # merge species
            children = tree.get_leaves(node_id)

            # skip the complete subtree below this node, so that nested
            # single-species nodes are not reported a second time.
            stack = list(tree.node(node_id).succ)
            while stack:
                descendant = stack.pop()
                nodes_to_skip.add(descendant)
                stack.extend(tree.node(descendant).succ)
            for child in children:
                nodes_to_skip.add(child)

            mergers.append((node_id, children))
            # only one species can account for all terminals of a node
            break

    tree.dfs(tree.root, pre_function=merge_species)

    return mergers