예제 #1
0
    def tree_style_with_data(self, data=None, order=None, force_topology=False):
        """
        Build a TreeStyle and attach a table of per-leaf values to ``self.tree``.

        data: {leaf name -> {column -> value}}. When given, every leaf of
            ``self.tree`` is looked up by name and must provide every column.
        order: column names, left to right. Defaults to the columns of the
            first entry of ``data``.
        force_topology: copied onto the TreeStyle when ``data`` is given.

        Returns the TreeStyle. Header faces are added to the style and, when
        ``data`` is given, one aligned TextFace per column is added to each leaf.
        """

        ts = TreeStyle()
        if data:
            if not order:
                # Fall back to the columns of the first row.
                order = next(iter(data.values())).keys()
            ts.show_leaf_name = True
            ts.draw_guiding_lines = True
            ts.force_topology = force_topology

        # Materialise once: ``order`` may be None (no data, no columns) or a
        # one-shot iterable, and it is walked once for the header and once per leaf.
        columns = list(order) if order else []

        for i, x in enumerate(columns):
            tf = TextFace(x)
            tf.margin_left = 5
            ts.aligned_header.add_face(tf, column=i)
        if data:
            for leaf in self.tree.get_leaves():
                row = data[leaf.name]
                for i, col in enumerate(columns):
                    tf = TextFace(row[col])
                    tf.margin_left = 5
                    leaf.add_face(tf, column=i, position="aligned")

        return ts
예제 #2
0
 def get_tree_style(self):
     """Create a TreeStyle, remember it in ``self._treestyle`` and return it.

     The style uses ``self.custom_layout`` as layout function, hides leaf
     names and draws guiding lines.
     """
     style = TreeStyle()
     style.draw_guiding_lines = True
     style.show_leaf_name = False
     style.layout_fn = self.custom_layout
     self._treestyle = style
     return style
예제 #3
0
 def run_action_change_style(self, tree, a_data):
     """Swap the style of ``tree`` depending on what it currently shows.

     If the tree is not using the remembered style (``self._treestyle``), it
     is simply switched back to it. Otherwise a new TreeStyle with solid
     guiding lines coloured ``a_data`` is built, installed on the tree and
     remembered as the new ``self._treestyle``.
     """
     if not (tree.tree_style == self._treestyle):
         # Tree shows something else: restore the remembered style.
         tree.tree_style = self._treestyle
         return

     recoloured = TreeStyle()
     recoloured.layout_fn = self.custom_layout
     recoloured.show_leaf_name = False
     recoloured.draw_guiding_lines = True
     recoloured.guiding_lines_type = 0  # solid line
     recoloured.guiding_lines_color = a_data
     tree.tree_style = recoloured
     self._treestyle = recoloured
def main(args):
    """Prune a species tree to family level, annotate it with counts and render it.

    Reads the tree from ``args.infn`` and the plastid genome table from
    ``args.tablefn``; writes the rendered tree to ``args.outfn``.
    """
    # --- 1. Logging -------------------------------------------------------
    log = logging.getLogger(__name__)
    coloredlogs.install(fmt='%(asctime)s [%(levelname)s] %(message)s', level='DEBUG', logger=log)

    # --- 2. Local NCBI taxonomy database ----------------------------------
    ncbi = NCBITaxa()
    taxa_db = os.path.join(Path.home(), ".etetoolkit/taxa.sqlite")
    db_age_seconds = time.time() - os.path.getmtime(taxa_db)
    if db_age_seconds > 604800:  # 604800 s = 7 days
        ncbi.update_taxonomy_database()

    # --- 3. Species-level tree -> family-level tree ------------------------
    # 3.1 load the tree
    log.debug("Loading Tree...")
    t = Tree(args.infn, format=5)

    # 3.2 collect leaf names, with underscores turned into spaces
    log.debug("Gathering species(leaf) names...")
    species_set_from_tree = {
        leaf.name.replace("_", " ") for leaf in t.iter_leaves()
    }

    # 3.3 group the species by family
    log.debug("Constructing dict of species in family...")
    species_in_family = get_species_in_family(species_set_from_tree, ncbi)

    # 3.4 prune
    log.debug("Pruning Tree to family level...")
    prune_to_family(t, species_in_family)

    # --- 4. Per-family counts attached to the leaves -----------------------
    # 4.1 plastid genome records from the input table
    species_list_from_table = get_species_list_from_table(args.tablefn)

    # 4.2 plastid genome entries per family
    log.debug("Counting plastid genome entries per family...")
    genome_count_per_family = get_genome_count_per_family(species_list_from_table, species_in_family)

    # 4.3 attach genome counts and species counts to the tree
    log.debug("Attaching counts to Tree...")
    species_count_per_family = get_species_count_per_family(species_in_family)
    attach_counts_to_tree(t, genome_count_per_family, species_count_per_family)

    # --- 5. Circular rendering ---------------------------------------------
    ts = TreeStyle()
    ts.mode = "c"
    ts.draw_guiding_lines = True
    ts.show_leaf_name = False
    log.debug("Rendering Tree...")
    t.render(args.outfn, w=10000, h=10000, tree_style=ts)
예제 #5
0
def plot_tree_barplot(tree_file, taxon2mlst, header_list):
    '''
    Annotate a midpoint-rooted tree with an aligned "MLST" column.

    Each leaf gets a coloured swatch (one colour per distinct MLST value) and
    the MLST label next to it. Leaves whose name cannot be converted with
    ``int()`` or is missing from ``taxon2mlst`` are shown as ' na ' on a grey
    background.

    :param tree_file: an ete ``Tree`` instance (used and modified in place),
        or anything accepted by ``Tree()``
    :param taxon2mlst: dict, ``int(leaf name)`` -> MLST label
    :param header_list: not used by this function; kept so existing callers
        keep working

    :return: ``(tree, tree_style)``
    '''

    mlst_list = list(set(taxon2mlst.values()))
    mlst2color = dict(zip(mlst_list, get_spaced_colors(len(mlst_list))))
    mlst2color['-'] = 'white'

    if isinstance(tree_file, Tree):
        t1 = tree_file
    else:
        t1 = Tree(tree_file)

    # Calculate the midpoint node
    R = t1.get_midpoint_outgroup()
    # and set it as tree outgroup
    t1.set_outgroup(R)

    tss = TreeStyle()
    tss.draw_guiding_lines = True
    tss.guiding_lines_color = "gray"
    tss.show_leaf_name = False

    # Column header, placed above aligned column 1 (the label column).
    header = TextFace('MLST')
    header.margin_top = 1
    header.margin_right = 2
    header.margin_left = 2
    header.margin_bottom = 1
    header.rotation = 90
    header.inner_background.color = "white"
    header.opacity = 1.
    header.hz_align = 2
    header.vt_align = 2
    tss.aligned_header.add_face(header, 1)

    for lf in t1.iter_leaves():
        try:
            mlst = taxon2mlst[int(lf.name)]
            swatch_color = mlst2color[mlst]
        except (KeyError, ValueError, TypeError):
            # Non-numeric leaf name or taxon without MLST information.
            label = TextFace(' na ')
            label.inner_background.color = "grey"
            swatch = TextFace('    ')
            swatch.inner_background.color = "white"
        else:
            label = TextFace(' %s ' % (mlst,))
            label.inner_background.color = 'white'
            swatch = TextFace('  ')
            swatch.inner_background.color = swatch_color

        label.opacity = 1.
        label.margin_top = 2
        label.margin_right = 2
        label.margin_left = 0
        label.margin_bottom = 2

        swatch.margin_top = 2
        swatch.margin_right = 0
        swatch.margin_left = 2
        swatch.margin_bottom = 2

        lf.add_face(swatch, 0, position="aligned")
        lf.add_face(label, 1, position="aligned")

        # Leaf names are hidden in the style, so draw the name explicitly.
        name_face = TextFace(lf.name, fgcolor="black", fsize=12, fstyle='italic')
        lf.add_face(name_face, 0)

    # Nodes with support below 1 get a visible black dot; the others are hidden.
    for node in t1.traverse():
        nstyle = NodeStyle()
        if node.support < 1:
            nstyle["fgcolor"] = "black"
            nstyle["size"] = 6
        else:
            nstyle["fgcolor"] = "red"
            nstyle["size"] = 0
        node.set_style(nstyle)

    return t1, tss
    else:
        F = faces.TextFace(mynode.name,fsize=20)
        faces.add_face_to_node(F,mynode,0,position="aligned")

# Plot pie chart: render plot_tree using the phyparts_pie_layout layout function.
ts = TreeStyle()
ts.show_leaf_name = False

ts.layout_fn = phyparts_pie_layout
# A single NodeStyle with node size 0 is applied to every node of the tree.
nstyle = NodeStyle()
nstyle["size"] = 0
for n in plot_tree.traverse():
	n.set_style(nstyle)
	# Vertical branch lines get width 0.
	n.img_style["vt_line_width"] = 0

ts.draw_guiding_lines = True
ts.guiding_lines_color = "black"
ts.guiding_lines_type = 0
ts.scale = 30
ts.branch_vertical_margin = 10
# Both calls modify plot_tree in place before it is rendered.
plot_tree.convert_to_ultrametric()
plot_tree.ladderize(direction=1)    
# Render to the path given by args.svg_name; the return value is kept in my_svg.
my_svg = plot_tree.render(args.svg_name,tree_style=ts,w=595,dpi=300)

# Optionally show the tree with a second style that uses node_text_layout.
if args.show_nodes:
	node_style = TreeStyle()
	node_style.show_leaf_name=False
	node_style.layout_fn = node_text_layout
	plot_tree.show(tree_style=node_style)

     
예제 #7
0
def plot_phylum_counts(NOG_id,
                       rank='phylum',
                       colapse_low_species_counts=4,
                       remove_unlassified=True):
    '''
    Build a taxonomy tree annotated with genome counts for one NOG.

    Steps:
    1. load the per-taxon genome counts, the tree and the taxon labels for
       ``rank`` from the ``eggnog`` MySQL database (password read from the
       ``SQLPSW`` environment variable)
    2. keep only leaves with more than ``colapse_low_species_counts`` genomes
       (and, if ``remove_unlassified``, whose label does not contain
       'unclassified'), then prune the tree to them
    3. add aligned columns to each leaf: rank, number of genomes, number of
       genomes with the NOG, percentage and a stacked bar of that percentage
    4. label internal nodes of selected ranks on top of their branch

    :param NOG_id: identifier passed to ``get_NOG_taxonomy`` and shown in the header
    :param rank: taxonomic rank; used as table-name suffix and query value
    :param colapse_low_species_counts: leaves need strictly more genomes than this
    :param remove_unlassified: drop leaves labelled as unclassified
    :return: ``(tree, tree_style)``
    '''

    import MySQLdb
    import os
    from chlamdb.biosqldb import manipulate_biosqldb
    from ete3 import NCBITaxa, Tree, TextFace, TreeStyle, StackedBarFace
    ncbi = NCBITaxa()

    sqlpsw = os.environ['SQLPSW']
    conn = MySQLdb.connect(
        host="localhost",  # your host, usually localhost
        user="******",  # your username
        passwd=sqlpsw,  # your password
        db="eggnog")  # name of the data base
    try:
        cursor = conn.cursor()

        # NOTE(review): ``rank`` is spliced into table names, which cannot be
        # bound as query parameters - callers must pass trusted values only.
        cursor.execute('select * from eggnog.leaf2n_genomes_%s' % rank)
        leaf_taxon2n_species = manipulate_biosqldb.to_dict(cursor.fetchall())

        leaf_taxon2n_species_with_domain = get_NOG_taxonomy(NOG_id, rank)

        # The rank value itself is bound by the driver instead of formatted in.
        cursor.execute(
            'select phylogeny from eggnog.phylogeny where rank=%s', (rank, ))
        newick = cursor.fetchall()[0][0]

        cursor.execute('select * from eggnog.taxid2label_%s' % rank)
        taxon_id2scientific_name_and_rank = manipulate_biosqldb.to_dict(
            cursor.fetchall())
    finally:
        # Everything needed has been fetched; do not leak the connection.
        conn.close()

    tree = Tree(newick, format=1)
    taxon_id2scientific_name_and_rank = {
        str(k): v
        for k, v in taxon_id2scientific_name_and_rank.items()
    }

    tss = TreeStyle()
    tss.draw_guiding_lines = True
    tss.guiding_lines_color = "blue"
    tss.show_leaf_name = False

    keep = []
    for lf in tree.iter_leaves():
        if remove_unlassified:
            label = taxon_id2scientific_name_and_rank[str(lf.name)][0]
            if 'unclassified' in label:
                continue

        n_genomes = int(leaf_taxon2n_species[lf.name])
        if n_genomes > colapse_low_species_counts:
            keep.append(lf.name)
    print('number of leaaves:', len(keep))

    tree.prune(keep)

    # NOTE(review): headers occupy aligned columns 0-3 while the leaf faces
    # below use columns 1-5 - confirm that this offset is intended.
    header_list = ['Rank', 'N genomes', 'N with %s' % NOG_id, 'Percentage']
    for col, header in enumerate(header_list):

        n = TextFace('%s' % (header))
        n.margin_top = 0
        n.margin_right = 1
        n.margin_left = 20
        n.margin_bottom = 1
        n.rotation = 270
        n.hz_align = 2
        n.vt_align = 2
        n.inner_background.color = "white"
        n.opacity = 1.
        tss.aligned_header.add_face(n, col)

    def make_cell(text, fgcolor="black"):
        # Small white-background text cell shared by all numeric columns.
        cell = TextFace(text, fgcolor=fgcolor)
        cell.margin_top = 1
        cell.margin_right = 1
        cell.margin_left = 0
        cell.margin_bottom = 1
        cell.fsize = 7
        cell.inner_background.color = "white"
        cell.opacity = 1.
        return cell

    for lf in tree.iter_leaves():
        n_genomes = int(leaf_taxon2n_species[lf.name])
        if n_genomes <= colapse_low_species_counts:
            continue

        # n genomes
        lf.add_face(make_cell('  %s ' % str(leaf_taxon2n_species[lf.name])),
                    2,
                    position="aligned")

        # n genomes with domain (taxa absent from the NOG count as 0)
        try:
            with_domain_text = '  %s ' % str(
                leaf_taxon2n_species_with_domain[lf.name])
        except (KeyError, TypeError):
            with_domain_text = '  0 '
        lf.add_face(make_cell(with_domain_text), 3, position="aligned")

        # rank
        ranks = ncbi.get_rank([lf.name])
        try:
            r = ranks[max(ranks.keys())]
        except (ValueError, TypeError):
            # No rank returned for this leaf name.
            r = '-'
        lf.add_face(make_cell('  %s ' % r, fgcolor='red'),
                    1,
                    position="aligned")

        # percent with target domain
        try:
            percentage = (float(leaf_taxon2n_species_with_domain[lf.name]) /
                          float(leaf_taxon2n_species[lf.name])) * 100
        except (KeyError, TypeError, ValueError, ZeroDivisionError):
            percentage = 0
        lf.add_face(make_cell('  %s ' % str(round(percentage, 2))),
                    4,
                    position="aligned")

        b = StackedBarFace([percentage, 100 - percentage],
                           width=100,
                           height=10,
                           colors=["#7fc97f", "white"])
        b.rotation = 0
        b.inner_border.color = "grey"
        b.inner_border.width = 0
        b.margin_right = 15
        b.margin_left = 0
        lf.add_face(b, 5, position="aligned")

        scientific_name = taxon_id2scientific_name_and_rank[str(lf.name)][0]
        n = TextFace('%s' % scientific_name, fgcolor="black",
                     fsize=9)  # , fstyle = 'italic'

        # The leaf is renamed here, so it no longer matches a taxon id in the
        # label dictionary and is skipped by the loop below.
        lf.name = " %s (%s)" % (scientific_name, str(lf.name))
        n.margin_right = 10
        lf.add_face(n, 0)

    labelled_ranks = ['phylum', 'superkingdom', 'class', 'subphylum']
    for node in tree.traverse("postorder"):
        entry = taxon_id2scientific_name_and_rank.get(str(node.name))
        if entry is None:
            # Unknown node name: nothing to label.
            continue
        try:
            node_label, node_rank = entry[0], entry[1]
        except (IndexError, TypeError):
            continue
        if node_rank in labelled_ranks or node_label in ['FCB group']:
            hola = TextFace("%s" % (node_label))
            node.add_face(hola, column=0, position="branch-top")
    return tree, tss
예제 #8
0
def make_tree_ali_detect_combi(g_tree,
                               ali_nf,
                               Out,
                               hist_up="",
                               x_values=None,
                               hp=None,
                               dict_benchmark=None,
                               reorder=False,
                               det_tool=False,
                               title=""):
    """
    Render a tree next to its alignment, with per-position score tracks on top.

    :param g_tree: object providing ``reptree``, ``cz_nodes``, ``outgroup_ND``,
        ``conv_events`` and ``annotated_tree_fn_est``
    :param ali_nf: alignment file linked to the tree
    :param Out: output file passed to ``render``
    :param hist_up: when one of "PCOC", "PC", "OC", "Topological",
        "Identical", ``dict_benchmark[hist_up]`` is drawn as a histogram
    :param x_values: x-axis positions; when empty, positions 1..N are used,
        N being the length of the first entry of ``dict_benchmark``
    :param hp: stored on the header faces as ``hp``
    :param dict_benchmark: scores per model - presumably one value per
        alignment position (TODO confirm)
    :param reorder: draw coloured boxes on the histogram
    :param det_tool: use the simplified node styling and skip the root marker
    :param title: optional title added to the figure
    :return: the value returned by the final ``render`` call
    """
    # Fresh containers per call rather than shared mutable defaults.
    x_values = [] if x_values is None else x_values
    hp = [] if hp is None else hp
    dict_benchmark = {} if dict_benchmark is None else dict_benchmark

    hist_models = ["PCOC", "PC", "OC", "Topological", "Identical"]

    reptree = g_tree.reptree
    cz_nodes = g_tree.cz_nodes
    ### Tree

    ## Tree style
    phylotree_style = TreeStyle()
    phylotree_style.show_leaf_name = False
    phylotree_style.show_branch_length = False
    phylotree_style.draw_guiding_lines = True
    phylotree_style.min_leaf_separation = 1

    ## For noisy tree
    cz_nodes_s = {}
    if cz_nodes:
        cols = [
            "#008000", "#800080", "#007D80", "#9CA1A2", "#A52A2A", "#ED8585",
            "#FF8EAD", "#8EB1FF", "#FFE4A1", "#ADA1FF"
        ]
        for col_i, Cz in enumerate(cz_nodes.keys()):
            # Wrap around so more groups than palette entries do not crash.
            colour = cols[col_i % len(cols)]
            style = NodeStyle()
            style["fgcolor"] = colour
            style["size"] = 5
            style["hz_line_width"] = 2
            style["vt_line_width"] = 2
            style["vt_line_color"] = colour
            style["hz_line_color"] = colour
            cz_nodes_s[Cz] = style

    sim_root_ND = g_tree.outgroup_ND

    def my_layout(node):
        ## Sequence name
        F = TextFace(node.name, tight_text=True)
        add_face_to_node(F, node, column=0, position="aligned")

        ## Sequence motif
        if node.is_leaf():
            motifs_n = []
            box_color = "black"
            opacity = 1
            if node.T == "True" or node.C == "True":
                # Box around the sequence of leaves flagged T or C.
                motifs_n.append([
                    0,
                    len(node.sequence), "[]", 10, 12, box_color, box_color,
                    None
                ])

            motifs_n.append(
                [0, len(node.sequence), "seq", 10, 10, None, None, None])
            seq_face = SeqMotifFace(seq=node.sequence,
                                    seqtype='aa',
                                    seq_format='seq',
                                    fgcolor=box_color,
                                    motifs=motifs_n)
            seq_face.overlaping_motif_opacity = opacity
            add_face_to_node(seq_face, node, column=1, position='aligned')

        ## Nodes style
        # The first matching rule wins; order matters.
        if det_tool and node.T == "True":
            node.set_style(nstyle_T_sim)
            add_t(node)
        elif det_tool and node.C == "True":
            node.set_style(nstyle_C_sim)
        elif det_tool:
            node.set_style(nstyle)
        #if not det_tool no background
        elif node.T == "True" and not int(
                node.ND) in g_tree.conv_events.nodesWithTransitions_est:
            node.set_style(nstyle_T_sim)
            add_t(node)
        elif node.T == "True" and int(
                node.ND) in g_tree.conv_events.nodesWithTransitions_est:
            node.set_style(nstyle_T_sim_est)
            add_t(node)
        elif int(node.ND) in g_tree.conv_events.nodesWithTransitions_est:
            node.set_style(nstyle_T_est)
        elif node.C == "True" and not int(
                node.ND) in g_tree.conv_events.nodesWithConvergentModel_est:
            node.set_style(nstyle_C_sim)
        elif node.C == "True" and int(
                node.ND) in g_tree.conv_events.nodesWithConvergentModel_est:
            node.set_style(nstyle_C_sim_est)
        elif int(node.ND) in g_tree.conv_events.nodesWithConvergentModel_est:
            node.set_style(nstyle_C_est)
        elif cz_nodes_s and node.Cz != "False":
            node.set_style(cz_nodes_s[int(node.Cz)])
            if int(node.ND) == int(cz_nodes[int(node.Cz)][0]):
                add_t(node)
        else:
            node.set_style(nstyle)

        if int(node.ND) == sim_root_ND and not det_tool:
            add_sim_root(node)

    phylotree_style.layout_fn = my_layout

    # Get tree dimensions
    tree_nf = g_tree.annotated_tree_fn_est
    logger.debug("tree_nf: %s", tree_nf)
    t = PhyloTree(tree_nf)
    t.link_to_alignment(ali_nf)
    t.render(Out)
    tree_face = TreeFace(t, phylotree_style)
    tree_face.update_items()
    tree_h = tree_face._height()

    ### X axes:
    if not x_values:  # Complete representation
        # list(...) keeps this working on Python 3, where dict views cannot
        # be indexed; an empty dict still raises IndexError as before.
        n_positions = len(list(dict_benchmark.values())[0])
        x_values_up = list(range(1, n_positions + 1))
        inter = 5
    else:  # Filtered representation
        x_values_up = x_values
        inter = 1

    ### Histogram up
    show_hist = hist_up in hist_models
    if show_hist:
        header_hist_value_up = 'Posterior probability (' + hist_up.upper(
        ) + ' model)'
        # Define emphased lines
        hlines = [0.8, 0.9, 0.99]
        hlines_col = ['#EFDB00', 'orange', 'red']

        hist = SequencePlotFace_mod(dict_benchmark[hist_up],
                                    hlines=hlines,
                                    hlines_col=hlines_col,
                                    kind='stick',  # bar/curve/stick
                                    header=header_hist_value_up,
                                    height=40,
                                    col_width=10,
                                    ylim=[0, 1])

        hist.hp = hp
        hist.x_values = x_values_up
        hist.x_inter_values = inter
        hist.set_sticks_color()

        if reorder:
            # draw colored boxes
            hist.put_colored_boxes = (True, tree_h + 40)

        phylotree_style.aligned_header.add_face(hist, 1)

    ### Rect all model:
    sequencescorebox = SequenceScoreFace(dict_benchmark, col_width=10)
    sequencescorebox.hp = hp
    sequencescorebox.x_values = x_values_up
    sequencescorebox.x_inter_values = inter
    if not show_hist:
        # Without a histogram the score box carries the x axis itself.
        sequencescorebox.x_axis = True

    phylotree_style.aligned_header.add_face(sequencescorebox, 1)

    if title:
        phylotree_style.title.add_face(TextFace(title), column=0)

    # This path is only logged; the tree rendered below is the one loaded above.
    tree_nf = reptree + "/annotated_tree.nhx"
    logger.debug("tree_nf: %s", tree_nf)

    res = t.render(Out, tree_style=phylotree_style)
    return (res)
예제 #9
0
    ):  # only iterate over columns that are not gaps in target seq
        if keySeq[i] != "-":  # meaning anything except gaps
            mappedCols[
                keyIndex] = i + 1  # map the alignment column to the key column. Gotta shift the index!
            keyIndex += 1

    return mappedCols


## ETE3 TREE-VIZ FUNCTIONS ##

# basic tree style
tree_style = TreeStyle()
tree_style.show_leaf_name = False
tree_style.show_branch_length = False
tree_style.draw_guiding_lines = True
tree_style.complete_branch_lines_when_necessary = True

# make tree grow upward
tree_style.rotation = 270
# and make it appear ultrametric (which it is!)
tree_style.optimal_scale_level = "full"

# internal node style
nstyle = NodeStyle()
nstyle["fgcolor"] = "black"
nstyle["size"] = 0

# terminal node style
nstyle_L = NodeStyle()
# Configure the terminal style itself; the original line set the colour on
# ``nstyle`` a second time and left ``nstyle_L`` untouched.
nstyle_L["fgcolor"] = "black"
def bub_tree(tree, fasta, outfile1, root, types, c_dict, show, size, colours,
             field1, field2, scale, multiplier, dna):
    """
    Style a tree leaf by leaf, add a time-point legend and render it.

    Leaf names are split on "_"; ``field2`` selects the time point and
    ``field1`` the frequency. Names with fewer than four fields (or without
    field ``field2``) are reported, get the time point 'zero' and the default
    node size.

    :param tree: tree object from ete; re-rooted, styled and ladderized in place
    :param fasta: mapping of leaf name -> sequence, or None for no sequence faces
    :param outfile1: output file passed to ``render``
    :param root: leaf name to use as outgroup, or None to keep the rooting
    :param types: tree mode: circular (c) or rectangle (r)
    :param c_dict: dictionary mapping time point -> colour
    :param show: the tree is shown in a gui only when this is exactly True
    :param size: leaf size is scaled by frequency only when this is exactly True
    :param colours: when equal to 3, the sequence face uses its default
        colours instead of the palettes defined here
    :param field1: index of the field that contains the size/frequency value
    :param field2: index of the field that contains the time point
    :param scale: how much to scale the x axis
    :param multiplier: factor applied to the frequency (500 when None)
    :param dna: true/false, is sequence a DNA sequence?
    :return: None, the image of the tree is written to ``outfile1``
    """

    mult = 500 if multiplier is None else multiplier

    if dna:
        dna_prot = 'dna'
        bg_c = {
            'A': 'green',
            'C': 'blue',
            'G': 'black',
            'T': 'red',
            '-': 'grey',
            'X': 'white'
        }
        # Black letters everywhere, except on the white 'X' background.
        fg_c = dict.fromkeys(bg_c, 'black')
        fg_c['X'] = 'white'
    else:
        dna_prot = 'aa'
        bg_c = {
            'K': '#145AFF',
            'R': '#145AFF',
            'H': '#8282D2',
            'E': '#E60A0A',
            'D': '#E60A0A',
            'N': '#00DCDC',
            'Q': '#00DCDC',
            'S': '#FA9600',
            'T': '#FA9600',
            'L': '#0F820F',
            'I': '#0F820F',
            'V': '#0F820F',
            'Y': '#3232AA',
            'F': '#3232AA',
            'W': '#B45AB4',
            'C': '#E6E600',
            'M': '#E6E600',
            'A': '#C8C8C8',
            'G': '#EBEBEB',
            'P': '#DC9682',
            '-': 'grey',
            'X': 'white'
        }
        # Black letters, with gaps and 'X' blending into their background.
        fg_c = dict.fromkeys(bg_c, 'black')
        fg_c['-'] = 'grey'
        fg_c['X'] = 'white'

    if colours == 3:
        bg_c = None
        fg_c = None

    tstyle = TreeStyle()
    tstyle.force_topology = False
    tstyle.mode = types
    tstyle.scale = scale
    tstyle.min_leaf_separation = 0
    tstyle.optimal_scale_level = 'full'  # 'mid'
    if types == 'c':
        tstyle.root_opening_factor = 0.25

    tstyle.draw_guiding_lines = False
    tstyle.guiding_lines_color = 'slateblue'
    tstyle.show_leaf_name = False
    tstyle.allow_face_overlap = True
    tstyle.show_branch_length = False
    tstyle.show_branch_support = False

    if root is not None:
        tree.set_outgroup(root)

    time_col = []
    for node in tree.traverse():
        if not node.is_leaf():
            # Internal nodes: thick branches, practically invisible node.
            inner_style = NodeStyle()
            inner_style["size"] = 0.1
            inner_style["hz_line_width"] = 10
            inner_style["vt_line_width"] = 10
            node.set_style(inner_style)
            continue

        fields = node.name.split("_")
        try:
            time_point = fields[field2]
        except (IndexError, TypeError):
            time_point = None
        # Names are expected to carry at least four "_"-separated fields.
        if time_point is None or len(fields) < 4:
            time_point = 'zero'
            # Do not read a frequency out of a malformed name.
            fields = None
            print("Incorrect name format for ", node.name)

        s = 20
        if size is True:
            try:
                s = 20 + float(fields[field1]) * mult
            except (IndexError, TypeError, ValueError):
                print("No frequency information for ", node.name)

        colour = c_dict[time_point]
        time_col.append((time_point, colour))
        leaf_style = NodeStyle()
        leaf_style["fgcolor"] = colour
        leaf_style["size"] = s
        leaf_style["shape"] = "circle"
        leaf_style["hz_line_width"] = 10
        leaf_style["vt_line_width"] = 10
        leaf_style["hz_line_color"] = colour
        leaf_style["vt_line_color"] = 'black'
        leaf_style["hz_line_type"] = 0
        leaf_style["vt_line_type"] = 0
        node.set_style(leaf_style)

        if root is not None and node.name == root:
            # place holder in case you want to do something with the root leaf
            print('root is ', node.name)

        if fasta is not None:
            seq = fasta[str(node.name)]
            seqFace = SequenceFace(seq,
                                   seqtype=dna_prot,
                                   fsize=10,
                                   fg_colors=fg_c,
                                   bg_colors=bg_c,
                                   codon=None,
                                   col_w=40,
                                   alt_col_w=3,
                                   special_col=None,
                                   interactive=True)
            # Attach to this very leaf; a by-name search of the whole tree is
            # slower and picks the wrong node when names are duplicated.
            node.add_face(seqFace, 0, "aligned")

    tree.ladderize()

    # One legend entry per distinct (time point, colour) pair, plus a blank
    # white entry at the end.
    legendkey = sorted(set(time_col))
    legendkey.append(('', 'white'))

    for tm, clr in legendkey:
        tstyle.legend.add_face(faces.CircleFace(30, clr), column=0)
        tstyle.legend.add_face(faces.TextFace('\t' + tm,
                                              ftype='Arial',
                                              fsize=60,
                                              fgcolor='black',
                                              tight_text=True),
                               column=1)
    if show is True:
        tree.show(tree_style=tstyle)

    tree.render(outfile1, dpi=600, tree_style=tstyle)
예제 #11
0
def _barplot_header_face(text, margin_top=1):
    """Build a rotated, white-background ``TextFace`` used as a column title."""
    face = TextFace(text)
    face.margin_top = margin_top
    face.margin_right = 1
    face.margin_left = 20
    face.margin_bottom = 1
    face.rotation = 270
    face.hz_align = 2
    face.vt_align = 2
    face.inner_background.color = "white"
    face.opacity = 1.
    return face


def _barplot_lookup(mapping, leaf_name):
    """Return ``mapping[leaf_name]``, retrying with ``int(leaf_name)`` as key."""
    try:
        return mapping[leaf_name]
    except (KeyError, TypeError):
        return mapping[int(leaf_name)]


def _barplot_add_headers(tss, taxon2label, header_list, header_list3,
                         header_list2, with_counts):
    """Add the column titles to ``tss.aligned_header``, left to right."""
    if taxon2label:
        # Blank placeholder above the label column.
        tss.aligned_header.add_face(_barplot_header_face('  '), 0)
    next_col = 1

    if header_list:
        column = next_col - 1
        for column, header in enumerate(header_list, start=next_col):
            tss.aligned_header.add_face(
                _barplot_header_face('%s' % header, margin_top=0), column)
        next_col = column + 1

    if header_list3:
        offset = 0
        for header in header_list3:
            title = _barplot_header_face('%s' % header)
            if with_counts:
                # Each set takes two columns here: a blank spacer and then
                # the title itself.
                if offset == 0:
                    offset += 1
                tss.aligned_header.add_face(title, offset + 1 + next_col)
                tss.aligned_header.add_face(TextFace('       '),
                                            offset + next_col)
                offset += 2
            else:
                tss.aligned_header.add_face(title, offset + next_col)
                offset += 1
        next_col += offset

    if header_list2:
        for column, header in enumerate(header_list2, start=next_col):
            tss.aligned_header.add_face(
                _barplot_header_face('%s' % header), column)


def plot_tree_stacked_barplot(
        tree_file,
        taxon2value_list_barplot=False,
        header_list=False,  # header stackedbarplots
        taxon2set2value_heatmap=False,
        taxon2label=False,
        header_list2=False,  # header counts columns
        biodb=False,
        column_scale=True,
        general_max=False,
        header_list3=False,
        set2taxon2value_list_simple_barplot=False,
        set2taxon2value_list_simple_barplot_counts=True,
        rotate=False,
        taxon2description=False):
    '''
    Decorate a tree with aligned label, bar-plot and heatmap columns.

    The tree is loaded from ``tree_file`` and re-rooted on its midpoint
    outgroup. Every leaf then receives, from left to right: an optional
    text label, stacked bar plots, simple bar plots (optionally preceded by
    the value as text), heatmap cells, and finally its description next to
    the branch. Leaf names are used as dictionary keys; for labels and
    stacked bars the lookup is retried with ``int(leaf.name)``.

    :param tree_file: anything accepted by ``Tree(...)``.
    :param taxon2value_list_barplot: {taxon: [[bar1_part1, bar1_part2, ...],
        [bar2_part1, ...]]}; the values of each inner list are converted to
        percentages of that list's sum (all zeros when the sum is 0).
    :param header_list: titles of the stacked bar-plot columns.
    :param taxon2set2value_heatmap: indexed as ``[column][taxon]``; only
        values whose integer part is > 0 are drawn as coloured cells.
    :param taxon2label: {taxon: label}; missing taxa are labelled ``'-'``.
    :param header_list2: titles (and keys) of the heatmap columns.
    :param biodb: when truthy, ``taxon2description`` is loaded from this
        database through ``manipulate_biosqldb``.
    :param column_scale: when truthy (and ``header_list2`` is given) one
        colour scale per heatmap column is built. The heatmap cell colouring
        relies on these scales, so positive cells cannot be drawn without it.
    :param general_max: accepted for backward compatibility; not used.
    :param header_list3: titles (and keys) of the simple bar-plot sets.
    :param set2taxon2value_list_simple_barplot: {set: {taxon: value}}; each
        bar is scaled to the largest value of its set.
    :param set2taxon2value_list_simple_barplot_counts: when truthy, the
        value is also written as text in front of each simple bar.
    :param rotate: when truthy, leaf faces are rotated by 270 degrees.
    :param taxon2description: {taxon: description}; when falsy and no
        ``biodb`` is given, the leaf name itself is displayed.
    :return: tuple ``(tree, tree_style)``.
    '''

    if biodb:
        from chlamdb.biosqldb import manipulate_biosqldb
        server, db = manipulate_biosqldb.load_db(biodb)

        taxon2description = manipulate_biosqldb.taxon_id2genome_description(
            server, biodb, filter_names=True)

    t1 = Tree(tree_file)

    # Re-root the tree on its midpoint outgroup.
    R = t1.get_midpoint_outgroup()
    t1.set_outgroup(R)

    colors2 = [
        "red", "#FFFF00", "#58FA58", "#819FF7", "#F781F3", "#2E2E2E",
        "#F7F8E0", 'black'
    ]
    colors = [
        "#7fc97f", "#386cb0", "#fdc086", "#ffffb3", "#fdb462", "#f0027f",
        "#F7F8E0", 'black'
    ]
    simple_bar_colors = [
        '#fc8d59', '#91bfdb', '#99d594', '#c51b7d', '#f1a340', '#999999'
    ]

    tss = TreeStyle()
    tss.draw_guiding_lines = True
    tss.guiding_lines_color = "gray"
    tss.show_leaf_name = False

    if column_scale and header_list2:
        import matplotlib.cm as cm
        from matplotlib.colors import rgb2hex
        import matplotlib as mpl
        # column -> [ScalarMappable, upper bound of the scale]
        column2scale = {}
        for col_n, column in enumerate(header_list2):
            values = taxon2set2value_heatmap[column].values()
            lowest, highest = min(values), max(values)
            if lowest == highest:
                # Constant column: widen the range so the colour is not
                # taken from the very end of the colormap.
                min_val = 0
                max_val = 1.5 * highest
            else:
                min_val = lowest
                max_val = highest
            norm = mpl.colors.Normalize(vmin=min_val, vmax=max_val)
            # First four columns use one colormap, the remaining another.
            cmap = cm.OrRd if col_n < 4 else cm.YlGnBu
            m = cm.ScalarMappable(norm=norm, cmap=cmap)
            column2scale[column] = [m, float(max_val)]

    # Largest value of each simple bar-plot set, computed once instead of
    # once per leaf.
    set2max = {}
    if set2taxon2value_list_simple_barplot:
        for one_set in header_list3:
            set2max[one_set] = max(
                float(v) for v in
                set2taxon2value_list_simple_barplot[one_set].values())

    _barplot_add_headers(tss, taxon2label, header_list, header_list3,
                         header_list2,
                         set2taxon2value_list_simple_barplot_counts)

    for lf in t1.iter_leaves():

        # --- label column (always column 1) ---
        if taxon2label:
            try:
                label = _barplot_lookup(taxon2label, lf.name)
            except (KeyError, ValueError, TypeError):
                label = '-'
            n = TextFace('%s' % label)
            n.margin_top = 1
            n.margin_right = 1
            n.margin_left = 20
            n.margin_bottom = 1
            n.inner_background.color = "white"
            n.opacity = 1.
            if rotate:
                n.rotation = 270
            lf.add_face(n, 1, position="aligned")
        # Data columns start at 2 whether or not a label was drawn.
        col_add = 2

        # --- stacked bar plots ---
        if taxon2value_list_barplot:
            val_list_of_lists = _barplot_lookup(taxon2value_list_barplot,
                                                lf.name)
            for col, value_list in enumerate(val_list_of_lists):
                total = float(sum(value_list))
                if total:
                    percentages = [(v / total) * 100 for v in value_list]
                else:
                    # Avoid dividing by zero for an all-zero bar.
                    percentages = [0.0 for _ in value_list]
                col_list = colors2 if col % 3 == 0 else colors
                b = StackedBarFace(percentages,
                                   width=150,
                                   height=18,
                                   colors=col_list[0:len(percentages)])
                b.rotation = 0
                b.inner_border.color = "white"
                b.inner_border.width = 0
                b.margin_right = 5
                b.margin_left = 5
                if rotate:
                    b.rotation = 270
                lf.add_face(b, col + col_add, position="aligned")
            col_add += len(val_list_of_lists)

        # --- simple bar plots (value relative to the set maximum) ---
        if set2taxon2value_list_simple_barplot:
            col = 0
            for set_index, one_set in enumerate(header_list3):
                color = simple_bar_colors[set_index % len(simple_bar_colors)]
                value = set2taxon2value_list_simple_barplot[one_set][lf.name]

                if set2taxon2value_list_simple_barplot_counts:
                    if isinstance(value, float):
                        a = TextFace(" %s " % str(round(value, 2)))
                    else:
                        a = TextFace(" %s " % str(value))
                    a.margin_top = 1
                    a.margin_right = 2
                    a.margin_left = 5
                    a.margin_bottom = 1
                    if rotate:
                        a.rotation = 270
                    lf.add_face(a, col + col_add, position="aligned")

                biggest = set2max[one_set]
                if biggest:
                    fraction_biggest = (float(value) / biggest) * 100
                else:
                    # Every value of the set is 0: draw an empty bar.
                    fraction_biggest = 0.0
                fraction_rest = 100 - fraction_biggest

                b = StackedBarFace([fraction_biggest, fraction_rest],
                                   width=100,
                                   height=15,
                                   colors=[color, 'white'])
                b.rotation = 0
                b.inner_border.color = "grey"
                b.inner_border.width = 0
                b.margin_right = 15
                b.margin_left = 0
                if rotate:
                    b.rotation = 270
                if set2taxon2value_list_simple_barplot_counts:
                    # Same column arithmetic as the header row: text and
                    # bar occupy separate columns.
                    if col == 0:
                        col += 1
                    lf.add_face(b, col + 1 + col_add, position="aligned")
                    col += 2
                else:
                    lf.add_face(b, col + col_add, position="aligned")
                    col += 1
            col_add += col

        # --- heatmap cells ---
        if taxon2set2value_heatmap:
            for col2, col_name in enumerate(header_list2):
                try:
                    value = taxon2set2value_heatmap[col_name][str(lf.name)]
                except (KeyError, TypeError):
                    try:
                        value = taxon2set2value_heatmap[col_name][round(
                            float(lf.name), 2)]
                    except (KeyError, ValueError, TypeError):
                        value = 0
                # Debug trace kept from the original implementation.
                if col_name == 'duplicates':
                    print('dupli', lf.name, value)
                count = int(value)
                if count > 0:
                    # Field width shrinks as the number gets longer.
                    if 10 <= count < 100:
                        n = TextFace('%4i' % value)
                    elif count >= 100:
                        n = TextFace('%3i' % value)
                    else:
                        n = TextFace('%5i' % value)

                    n.margin_top = 1
                    n.margin_right = 2
                    n.margin_left = 5
                    n.margin_bottom = 1
                    n.hz_align = 1
                    n.vt_align = 1
                    if rotate:
                        n.rotation = 270
                    n.inner_background.color = rgb2hex(
                        column2scale[col_name][0].to_rgba(float(value)))
                    if float(value) > column2scale[col_name][1]:
                        n.fgcolor = 'white'
                    n.opacity = 1.
                else:
                    # Zero (or missing) value: empty white cell.
                    n = TextFace('')
                    n.margin_top = 1
                    n.margin_right = 1
                    n.margin_left = 5
                    n.margin_bottom = 1
                    n.inner_background.color = "white"
                    n.opacity = 1.
                    if rotate:
                        n.rotation = 270
                lf.add_face(n, col2 + col_add, position="aligned")

        # --- description next to the branch ---
        if taxon2description:
            description = taxon2description[lf.name]
        else:
            description = lf.name
        n = TextFace(description,
                     fgcolor="black",
                     fsize=12,
                     fstyle='italic')
        lf.add_face(n, 0)

    # Nodes with support < 1 get a visible black dot; others are hidden.
    for n in t1.traverse():
        nstyle = NodeStyle()
        if n.support < 1:
            nstyle["fgcolor"] = "black"
            nstyle["size"] = 6
        else:
            nstyle["fgcolor"] = "red"
            nstyle["size"] = 0
        n.set_style(nstyle)

    return t1, tss
예제 #12
0
def _heat_tree_header_face(title):
    """Build a slanted, white-background ``TextFace`` used as column title."""
    face = TextFace(title)
    face.rotation = -25
    face.margin_top = 1
    face.margin_right = 1
    face.margin_left = 20
    face.margin_bottom = 1
    face.inner_background.color = "white"
    face.opacity = 1.
    return face


def _heat_tree_value_face(text, margin_right=0):
    """Build the small white-background ``TextFace`` showing a numeric value."""
    face = TextFace(text)
    face.margin_top = 1
    face.margin_right = margin_right
    face.margin_left = 0
    face.margin_bottom = 1
    face.fsize = 7
    face.inner_background.color = "white"
    face.opacity = 1.
    return face


def _heat_tree_bar_face(filled, rest, color, margin_left, margin_right):
    """Build a 100x9 two-part bar: ``filled`` in ``color``, ``rest`` in white."""
    bar = StackedBarFace([filled, rest],
                         width=100,
                         height=9,
                         colors=[color, 'white'])
    bar.rotation = 0
    bar.inner_border.color = "black"
    bar.inner_border.width = 0
    bar.margin_left = margin_left
    bar.margin_right = margin_right
    return bar


def plot_heat_tree(tree_file,
                   biodb="chlamydia_04_16",
                   exclude_outgroup=False,
                   bw_scale=True):
    """
    Decorate a tree with genome size, GC and non-coding bar plots.

    For every leaf the genome size (Mbp), GC content and non-coding
    percentage are fetched from ``biodb`` and drawn as a value followed by a
    horizontal bar. When the checkM tables can be queried, completeness and
    contamination columns are appended as well.

    :param tree_file: either a string accepted by ``Tree(...)`` (the tree is
        then re-rooted on its midpoint outgroup when possible) or an
        existing ``Tree`` instance (used as is).
    :param biodb: database name; it is interpolated into table names and
        into the SQL text, so it must come from a trusted source.
    :param exclude_outgroup: when truthy, the first leaf is left out of the
        scales and gets no data columns; its ``name`` is replaced by its
        description.
    :param bw_scale: when truthy every bar of a column uses one fixed
        colour; otherwise bars are coloured on a per-column colormap, except
        for the highlighted strain which keeps the fixed colour.
    :return: tuple ``(tree, tree_style)``.
    :raises IOError: if ``tree_file`` is neither a string nor a ``Tree``.
    """
    from chlamdb.biosqldb import manipulate_biosqldb
    import matplotlib.cm as cm
    from matplotlib.colors import rgb2hex
    import matplotlib as mpl

    server, db = manipulate_biosqldb.load_db(biodb)

    sql_biodatabase_id = 'select biodatabase_id from biodatabase where name="%s"' % biodb
    db_id = server.adaptor.execute_and_fetchall(sql_biodatabase_id, )[0][0]

    if isinstance(tree_file, str):
        t1 = Tree(tree_file)
        try:
            # Best effort: keep the tree as parsed if midpoint rooting fails.
            t1.set_outgroup(t1.get_midpoint_outgroup())
        except Exception:
            pass
    elif isinstance(tree_file, Tree):
        t1 = tree_file
    else:
        # The exception used to be instantiated without being raised, which
        # surfaced later as an unrelated NameError on ``t1``.
        raise IOError('Unknown tree format')

    tss = TreeStyle()
    tss.draw_guiding_lines = True
    tss.guiding_lines_color = "gray"
    tss.show_leaf_name = False

    sql2 = 'select t2.taxon_id, t1.GC from genomes_info_%s as t1 inner join bioentry as t2 ' \
           ' on t1.accession=t2.accession where t2.biodatabase_id=%s and t1.description not like "%%%%plasmid%%%%";' % (biodb, db_id)
    sql3 = 'select t2.taxon_id, t1.genome_size from genomes_info_%s as t1 ' \
           ' inner join bioentry as t2 on t1.accession=t2.accession ' \
           ' where t2.biodatabase_id=%s and t1.description not like "%%%%plasmid%%%%";' % (biodb, db_id)
    sql4 = 'select t2.taxon_id,percent_non_coding from genomes_info_%s as t1 ' \
           ' inner join bioentry as t2 on t1.accession=t2.accession ' \
           ' where t2.biodatabase_id=%s and t1.description not like "%%%%plasmid%%%%";' % (biodb, db_id)

    sql_checkm_completeness = 'select taxon_id, completeness from custom_tables.checkm_%s;' % biodb
    sql_checkm_contamination = 'select taxon_id,contamination from custom_tables.checkm_%s;' % biodb

    try:
        taxon_id2completeness = manipulate_biosqldb.to_dict(
            server.adaptor.execute_and_fetchall(sql_checkm_completeness))
        taxon_id2contamination = manipulate_biosqldb.to_dict(
            server.adaptor.execute_and_fetchall(sql_checkm_contamination))
    except Exception:
        # checkM data is optional: if either query fails, both columns are
        # dropped.
        taxon_id2completeness = False

    taxon2description = manipulate_biosqldb.taxon_id2genome_description(
        server, biodb, filter_names=True)

    taxon2gc = manipulate_biosqldb.to_dict(
        server.adaptor.execute_and_fetchall(sql2, ))
    taxon2genome_size = manipulate_biosqldb.to_dict(
        server.adaptor.execute_and_fetchall(sql3, ))
    taxon2coding_density = manipulate_biosqldb.to_dict(
        server.adaptor.execute_and_fetchall(sql4, ))

    my_taxons = [lf.name for lf in t1.iter_leaves()]

    if exclude_outgroup:
        # The first leaf is treated as the outgroup and kept out of the
        # scales.
        excluded = str(list(t1.iter_leaves())[0].name)
        my_taxons.pop(my_taxons.index(excluded))

    genome_sizes = [float(taxon2genome_size[i]) for i in my_taxons]
    gc_list = [float(taxon2gc[i]) for i in my_taxons]
    fraction_list = [float(taxon2coding_density[i]) for i in my_taxons]

    max_genome_size = max(genome_sizes)
    max_gc = max(gc_list)
    # Taken over every taxon returned by the query, not only the tree's.
    max_non_coding = max(taxon2coding_density.values())

    cmap = cm.YlGnBu

    norm = mpl.colors.Normalize(vmin=min(genome_sizes) - 100000,
                                vmax=max(genome_sizes))
    m1 = cm.ScalarMappable(norm=norm, cmap=cmap)
    norm = mpl.colors.Normalize(vmin=min(gc_list), vmax=max(gc_list))
    m2 = cm.ScalarMappable(norm=norm, cmap=cmap)
    norm = mpl.colors.Normalize(vmin=min(fraction_list),
                                vmax=max(fraction_list))
    m3 = cm.ScalarMappable(norm=norm, cmap=cmap)

    # Strain that always keeps the fixed column colour.
    highlighted = 'Rhabdochlamydia helveticae T3358'

    def pick_color(description, fixed_color, mappable, raw_value):
        """Return the fixed colour, or the colormap colour of ``raw_value``."""
        if bw_scale or description == highlighted:
            return fixed_color
        return rgb2hex(mappable.to_rgba(float(raw_value)))

    # Column titles sit above the bars (odd columns); the value columns in
    # front of them (even columns) get an empty header.
    titles = [(3, 'Size (Mbp)'), (5, 'GC (%)'), (7, 'Non coding (%)')]
    if taxon_id2completeness:
        titles += [(9, 'Completeness (%)'), (11, 'Contamination (%)')]
    for column, title in titles:
        tss.aligned_header.add_face(_heat_tree_header_face(title), column)
        tss.aligned_header.add_face(TextFace(''), column - 1)

    for i, lf in enumerate(t1.iter_leaves()):

        if exclude_outgroup and i == 0:
            lf.name = taxon2description[lf.name]
            continue

        description = taxon2description[lf.name]

        # --- genome size: value (col 2) + bar (col 3) ---
        size = taxon2genome_size[lf.name]
        n = _heat_tree_value_face(
            '  %s ' % str(round(size / float(1000000), 2)), margin_right=1)
        lf.add_face(n, 2, position="aligned")

        fraction_biggest = (float(size) / max_genome_size) * 100
        fraction_rest = 100 - fraction_biggest
        col = pick_color(description, '#fc8d59', m1, size)
        b = _heat_tree_bar_face(fraction_biggest, fraction_rest, col,
                                margin_left=0, margin_right=15)
        lf.add_face(b, 3, position="aligned")

        # --- GC content: value (col 4) + bar (col 5) ---
        gc = taxon2gc[lf.name]
        fraction_biggest = (float(gc) / max_gc) * 100
        fraction_rest = 100 - fraction_biggest
        col = pick_color(description, '#91bfdb', m2, gc)
        b = _heat_tree_bar_face(fraction_biggest, fraction_rest, col,
                                margin_left=0, margin_right=15)
        lf.add_face(b, 5, position="aligned")
        n = _heat_tree_value_face('  %s ' % str(round(float(gc), 2)))
        lf.add_face(n, 4, position="aligned")

        # --- non-coding fraction: value (col 6) + bar (col 7) ---
        non_coding = taxon2coding_density[lf.name]
        col = pick_color(description, '#99d594', m3, non_coding)
        n = _heat_tree_value_face('  %s ' % str(float(non_coding)))
        lf.add_face(n, 6, position="aligned")
        fraction = (float(non_coding) / max_non_coding) * 100
        fraction_rest = ((max_non_coding - non_coding) /
                         float(max_non_coding)) * 100
        b = _heat_tree_bar_face(fraction, fraction_rest, col,
                                margin_left=5, margin_right=1)
        lf.add_face(b, 7, position="aligned")

        if taxon_id2completeness:
            # --- completeness: value (col 8) + bar (col 9) ---
            fraction = float(taxon_id2completeness[lf.name])
            n = _heat_tree_value_face('  %s ' % str(fraction))
            lf.add_face(n, 8, position="aligned")
            b = _heat_tree_bar_face(fraction, 100 - fraction, "#d7191c",
                                    margin_left=5, margin_right=1)
            lf.add_face(b, 9, position="aligned")

            # --- contamination: value (col 10) + bar (col 11) ---
            fraction = float(taxon_id2contamination[lf.name])
            n = _heat_tree_value_face('  %s ' % str(fraction))
            lf.add_face(n, 10, position="aligned")
            b = _heat_tree_bar_face(fraction, 100 - fraction, "black",
                                    margin_left=5, margin_right=1)
            lf.add_face(b, 11, position="aligned")

        # --- description next to the branch ---
        n = TextFace(description,
                     fgcolor="black",
                     fsize=9,
                     fstyle='italic')
        n.margin_right = 30
        lf.add_face(n, 0)

    # Nodes with support < 1 get a visible black dot; others are hidden.
    for n in t1.traverse():
        nstyle = NodeStyle()
        if n.support < 1:
            nstyle["fgcolor"] = "black"
            nstyle["size"] = 6
        else:
            nstyle["fgcolor"] = "red"
            nstyle["size"] = 0
        n.set_style(nstyle)

    return t1, tss
예제 #13
0
def plot_tree_text_metadata(tree_file, header2taxon2text, ordered_header_list,
                            biodb):
    """
    Decorate a tree with one aligned text column per metadata header.

    The tree is loaded from ``tree_file`` and re-rooted on its midpoint
    outgroup. Column titles come from ``ordered_header_list``; the cell of
    each leaf is ``header2taxon2text[header][int(leaf.name)]``. The genome
    description of every leaf (looked up in ``biodb`` by ``leaf.name``) is
    drawn next to the branch.

    Returns the tuple ``(tree, tree_style)``.
    """
    from chlamdb.biosqldb import manipulate_biosqldb

    def white_face(content, top, left):
        # Text face on an opaque white background with the given margins.
        face = TextFace('%s' % content)
        face.margin_top = top
        face.margin_right = 1
        face.margin_left = left
        face.margin_bottom = 1
        face.inner_background.color = "white"
        face.opacity = 1.
        return face

    def title_face(title):
        # Column title: same white face, rotated and aligned.
        face = white_face(title, 0, 20)
        face.rotation = 270
        face.hz_align = 2
        face.vt_align = 2
        return face

    server, db = manipulate_biosqldb.load_db(biodb)

    t1 = Tree(tree_file)

    taxon2description = manipulate_biosqldb.taxon_id2genome_description(
        server, biodb, filter_names=True)

    # Root the tree on its midpoint outgroup.
    t1.set_outgroup(t1.get_midpoint_outgroup())

    tss = TreeStyle()
    tss.draw_guiding_lines = True
    tss.guiding_lines_color = "gray"
    tss.show_leaf_name = False

    for position, leaf in enumerate(t1.iter_leaves()):

        if position == 0:
            # Titles are added once, while handling the first leaf.
            for slot, title in enumerate(ordered_header_list):
                tss.aligned_header.add_face(title_face(title), slot)

        taxon_key = int(leaf.name)
        for slot, title in enumerate(ordered_header_list):
            cell = white_face(header2taxon2text[title][taxon_key], 1, 5)
            leaf.add_face(cell, slot + 1, position="aligned")

        # Show the description (instead of the taxon id) beside the branch.
        caption = TextFace(taxon2description[leaf.name],
                           fgcolor="black",
                           fsize=12,
                           fstyle='italic')
        leaf.add_face(caption, 0)

    for node in t1.traverse():
        style = NodeStyle()
        # Support < 1: visible black dot; otherwise a zero-size red one.
        if node.support < 1:
            style["fgcolor"] = "black"
            style["size"] = 6
        else:
            style["fgcolor"] = "red"
            style["size"] = 0
        node.set_style(style)

    return t1, tss
예제 #14
0
# T.dist = 0.0 # set germline distance to 0

# T.write(format=1, outfile=tree_file_name+".multifurc")

# Global rendering options for the tree plot. The commented-out assignments
# are alternative settings (circular mode, arc geometry, root opening) that
# are currently disabled.
ts = TreeStyle()
# ts.mode = "c"
ts.scale = 500
ts.optimal_scale_level = "full"
# ts.arc_start = 180 # -180
# ts.arc_span = 180 # 359
# Leaf names, branch lengths and branch supports are all hidden.
ts.show_leaf_name = False
ts.show_branch_length = False
ts.show_branch_support = False
# ts.root_opening_factor = 0.75
ts.draw_guiding_lines = False
# Uniform 50-unit margin on every side, no rotation.
ts.margin_left = 50
ts.margin_right = 50
ts.margin_top = 50
ts.margin_bottom = 50
ts.rotation = 0

# Load the sequence-uid -> isotype mapping from a hard-coded local path.
# Each line is split on whitespace; field 1 is the key and field 2 the value
# (0-based), field 0 is ignored.
# NOTE(review): the file is named .csv but is split on whitespace, not on
# commas - confirm the actual file format.
path_to_sequence_string_uid_to_isotype_map = "/Users/lime/Dropbox/quake/Bcell/selection/figures/treePlots/v2/Bcell_flu_high_res.sequences.isotypeDict.V6_Full.csv"
sequence_string_uid_to_isotype = {}
with open(path_to_sequence_string_uid_to_isotype_map) as f:
    for line in f:
        vals = line.rstrip().split()
        sequence_string_uid_to_isotype[vals[1]] = vals[2]

path_to_isotype_to_color_map = "/Users/lime/Dropbox/quake/Bcell/selection/figures/treePlots/v2/isotype_to_color_dict.json"
with open(path_to_isotype_to_color_map, 'rU') as f:
예제 #15
0
def plot_heatmap_tree_locus(biodb,
                            tree_file,
                            taxid2count,
                            taxid2identity=False,
                            taxid2locus=False,
                            reference_taxon=False,
                            n_paralogs_barplot=False):
    '''
    Plot a tree with aligned columns describing the homologs of each leaf.

    The tree is rooted at its midpoint. Every leaf gets a text column with
    its homolog count and, optionally:
        - a bar showing the count relative to the largest count
        - the identity of the closest homolog (coloured with the OrRd map)
        - the locus tag of the closest homolog
    Finally each leaf is renamed from its taxon id to the genome description.

    :param biodb: database identifier handed to manipulate_biosqldb.load_db
    :param tree_file: anything accepted by Tree()
    :param taxid2count: dict, str(leaf name) -> homolog count. Leaves missing
        from the dict are inserted with a count of 0 (the dict is modified
        in place).
    :param taxid2identity: optional dict, str(leaf name) -> identity (0-100).
        Missing or non-numeric entries are displayed as '-'.
    :param taxid2locus: optional dict, str(leaf name) -> locus tag. Missing
        entries are displayed as '-'.
    :param reference_taxon: leaf name of the reference genome; its identity
        cell is left blank on a dark background and its locus tag is red.
    :param n_paralogs_barplot: if true, add the bar column.

    :return: (tree, number of leaves, tree style)
    '''

    from chlamdb.biosqldb import manipulate_biosqldb

    def _header_face(title):
        # Rotated column title shown in the aligned header.
        face = TextFace(title)
        face.margin_top = 1
        face.margin_right = 1
        face.margin_left = 20
        face.margin_bottom = 1
        face.inner_background.color = "white"
        face.opacity = 1.
        face.rotation = -25
        return face

    def _cell_face(text, margin_left):
        # Text cell of an aligned column; the caller sets the colours.
        face = TextFace(text)
        face.margin_top = 2
        face.margin_right = 2
        face.margin_left = margin_left
        face.margin_bottom = 2
        face.opacity = 1.
        return face

    server, _db = manipulate_biosqldb.load_db(biodb)

    taxid2organism = manipulate_biosqldb.taxon_id2genome_description(
        server, biodb, True)

    t1 = Tree(tree_file)
    ts = TreeStyle()
    ts.draw_guiding_lines = True
    ts.guiding_lines_color = "gray"
    # Root the tree at its midpoint.
    t1.set_outgroup(t1.get_midpoint_outgroup())

    # Leaves without any homolog are reported with a count of 0.
    for lf in t1.iter_leaves():
        taxid2count.setdefault(str(lf.name), 0)

    max_count = max(taxid2count[str(lf.name)] for lf in t1.iter_leaves())

    # Column titles.
    # NOTE(review): titles sit in header columns 1-3 while the count column
    # itself is column 0, and the identity/locus columns shift by one when
    # the barplot is enabled -- confirm this offset is intended.
    ts.aligned_header.add_face(_header_face('Number of homologs'), 1)
    if taxid2identity:
        ts.aligned_header.add_face(_header_face('Protein identity'), 2)
    if taxid2locus:
        ts.aligned_header.add_face(_header_face('Locus tag'), 3)

    identity_colors = None
    if taxid2identity:
        # Build the identity -> colour mapper once instead of once per leaf.
        import matplotlib.cm as cm
        from matplotlib.colors import Normalize, rgb2hex

        identity_colors = cm.ScalarMappable(norm=Normalize(vmin=0, vmax=100),
                                            cmap=cm.OrRd)

    leaf_number = 0

    for lf in t1.iter_leaves():
        leaf_number += 1

        # Keep the taxon id: lf.name is overwritten at the end of the loop.
        taxid = str(lf.name)
        is_reference = taxid == str(reference_taxon)

        lf.branch_vertical_margin = 0

        # homolog count, always in the first aligned column
        count = taxid2count[taxid]
        col_index = 0
        n = _cell_face(' %s ' % str(count), margin_left=20)
        n.inner_background.color = "white"
        lf.add_face(n, col_index, position="aligned")

        # optionally indicate number of paralogs as a barplot
        if n_paralogs_barplot:
            col_index += 1
            # When no leaf has any homolog max_count is 0: draw an empty bar
            # instead of dividing by zero.
            percent = (float(count) / max_count) * 100 if max_count else 0.0
            n = StackedBarFace([percent, 100 - percent],
                               width=150,
                               height=18,
                               colors=['#6699ff', 'white'],
                               line_color='white')
            n.rotation = 0
            n.inner_border.color = "white"
            n.inner_border.width = 0
            n.margin_right = 15
            n.margin_left = 0
            lf.add_face(n, col_index, position="aligned")

        # optionally add additional column with identity
        if taxid2identity:
            try:
                identity = round(taxid2identity[taxid], 2)
            except (KeyError, TypeError):
                # no identity recorded (or not a number) for this taxon
                value = '-'
            else:
                # exactly 100 is displayed as "100.0", anything else with
                # two decimals
                value = ("%.2f" if identity != 100 else "%.1f") % identity
            has_identity = value != '-'
            if is_reference:
                value = '         '

            n = _cell_face(' %s ' % value, margin_left=20)
            if has_identity and not is_reference:
                n.inner_background.color = rgb2hex(
                    identity_colors.to_rgba(float(value)))
                # dark cells need a light font to stay readable
                if float(value) > 82:
                    n.fgcolor = 'white'
            if is_reference:
                n.inner_background.color = '#800000'

            lf.add_face(n, col_index + 1, position="aligned")

        # optionally add column with locus name
        if taxid2locus:
            try:
                value = str(taxid2locus[taxid])
            except (KeyError, TypeError):
                value = '-'
            n = _cell_face(' %s ' % value, margin_left=2)
            n.inner_background.color = "white"
            if is_reference:
                n.fgcolor = '#ff0000'
            lf.add_face(n, col_index + 2, position="aligned")

        lf.name = taxid2organism[taxid]

    return t1, leaf_number, ts
예제 #16
0
def draw_tree(tree, conf, outfile):
    """Render ``tree`` as an image, an SVG and a PDF.

    Three files are written: ``outfile``, ``outfile + '.svg'`` and
    ``outfile + '.pdf'`` (170 mm wide, 150 dpi). Leaves are labelled with
    their species (and gene name when the node has one), internal branches
    with their support, and leaves carrying a ``named_lineage`` get one
    coloured label per tracked clade. When any node has a ``sequence``, a
    block view of the aligned sequence is added next to each leaf.

    :param tree: tree object providing traverse(), render() and the node
        attributes read by the layout functions below.
    :param conf: unused; kept for interface compatibility.
    :param outfile: path of the first rendered file and prefix of the others.
    :return: None. If ete3 cannot be imported the error is printed and
        nothing is rendered.
    """
    try:
        from ete3 import (add_face_to_node, TextFace, TreeStyle, RectFace,
                          random_color, SeqMotifFace)
    except ImportError as e:
        print(e)
        return

    def ly_basic(node):
        # Hide node symbols, except the NPR markers on internal nodes.
        node.img_style['size'] = 0
        if not node.is_leaf():
            node.img_style['shape'] = 'square'
            if (len(MIXED_RES) > 1 and hasattr(node, "tree_seqtype")
                    and node.tree_seqtype == "nt"):
                node.img_style["bgcolor"] = "#CFE6CA"
                ntF = TextFace("nt",
                               fsize=6,
                               fgcolor='#444',
                               ftype='Helvetica')
                add_face_to_node(ntF, node, 10, position="branch-bottom")
            if len(NPR_TREES) > 1 and hasattr(node, "tree_type"):
                node.img_style['size'] = 4
                node.img_style['fgcolor'] = "steelblue"

        node.img_style['hz_line_width'] = 1
        node.img_style['vt_line_width'] = 1

    def ly_leaf_names(node):
        # Species name in italics, followed by the gene name when present.
        if node.is_leaf():
            spF = TextFace(node.species,
                           fsize=10,
                           fgcolor='#444444',
                           fstyle='italic',
                           ftype='Helvetica')
            add_face_to_node(spF, node, column=0, position='branch-right')
            if hasattr(node, 'genename'):
                geneF = TextFace(" (%s)" % node.genename,
                                 fsize=8,
                                 fgcolor='#777777',
                                 ftype='Helvetica')
                add_face_to_node(geneF,
                                 node,
                                 column=1,
                                 position='branch-right')

    def ly_supports(node):
        # Support value on top of every internal branch except the root.
        if not node.is_leaf() and node.up:
            supFace = TextFace("%0.2g" % (node.support),
                               fsize=7,
                               fgcolor='indianred')
            add_face_to_node(supFace, node, column=0, position='branch-top')

    def ly_tax_labels(node):
        # One coloured label per tracked clade found in the leaf lineage.
        if node.is_leaf():
            c = LABEL_START_COL
            for tname in TRACKED_CLADES:
                if hasattr(node,
                           "named_lineage") and tname in node.named_lineage:
                    linF = TextFace(tname, fsize=10, fgcolor='white')
                    linF.margin_left = 3
                    linF.margin_right = 2
                    linF.background.color = lin2color[tname]
                    add_face_to_node(linF, node, c, position='aligned')
                    c += 1

            # NOTE(review): c starts at LABEL_START_COL (10), so at most
            # len(TRACKED_CLADES) - 10 empty filler faces are added here;
            # confirm this is the intended padding.
            for _ in range(c, len(TRACKED_CLADES)):
                add_face_to_node(TextFace('', fsize=10, fgcolor='slategrey'),
                                 node,
                                 c,
                                 position='aligned')
                c += 1

    def ly_block_alg(node):
        # Draw every gap-free stretch of the leaf sequence as one block.
        if node.is_leaf() and 'sequence' in node.features:
            motifs = []
            last_lt = None  # start index of the stretch being scanned
            for c, lt in enumerate(node.sequence):
                if lt != '-':
                    if last_lt is None:
                        last_lt = c
                    if c + 1 == len(node.sequence):
                        # the stretch runs up to the end of the sequence
                        motifs.append([
                            last_lt, c, "()", 0, 12, "slategrey",
                            "slategrey", None
                        ])
                        last_lt = None
                elif last_lt is not None:
                    # a gap closes the current stretch
                    motifs.append([
                        last_lt, c - 1, "()", 0, 12, "grey", "slategrey",
                        None
                    ])
                    last_lt = None

            seqFace = SeqMotifFace(node.sequence,
                                   motifs,
                                   intermotif_format="line",
                                   seqtail_format="line",
                                   scale_factor=ALG_SCALE)
            add_face_to_node(seqFace, node, ALG_START_COL, aligned=True)

    TRACKED_CLADES = [
        "Eukaryota",
        "Viridiplantae",
        "Fungi",
        "Alveolata",
        "Metazoa",
        "Stramenopiles",
        "Rhodophyta",
        "Amoebozoa",
        "Crypthophyta",
        "Bacteria",
        "Alphaproteobacteria",
        "Betaproteobacteria",
        "Cyanobacteria",
        "Gammaproteobacteria",
    ]

    # ["Opisthokonta",  "Apicomplexa"]

    colors = random_color(num=len(TRACKED_CLADES), s=0.45)
    lin2color = dict(zip(TRACKED_CLADES, colors))

    LABEL_START_COL = 10
    ALG_START_COL = 40
    ts = TreeStyle()
    ts.draw_aligned_faces_as_table = False
    ts.draw_guiding_lines = False
    ts.show_leaf_name = False
    ts.show_branch_support = False
    ts.scale = 160

    ts.layout_fn = [ly_basic, ly_leaf_names, ly_supports, ly_tax_labels]

    # Collect what the layout functions and the legend depend on.
    MIXED_RES = set()
    MAX_SEQ_LEN = 0
    NPR_TREES = []
    for n in tree.traverse():
        if hasattr(n, "tree_seqtype"):
            MIXED_RES.add(n.tree_seqtype)
        if hasattr(n, "tree_type"):
            NPR_TREES.append(n.tree_type)
        MAX_SEQ_LEN = max(len(getattr(n, "sequence", "")), MAX_SEQ_LEN)

    if MAX_SEQ_LEN:
        # ly_block_alg is only registered when ALG_SCALE is defined.
        ALG_SCALE = min(1, 1000. / MAX_SEQ_LEN)
        ts.layout_fn.append(ly_block_alg)

    if len(NPR_TREES) > 1:
        rF = RectFace(4, 4, "steelblue", "steelblue")
        rF.margin_right = 10
        rF.margin_left = 10
        ts.legend.add_face(rF, 0)
        ts.legend.add_face(TextFace(" NPR node"), 1)
        ts.legend_position = 3

    if len(MIXED_RES) > 1:
        rF = RectFace(20, 20, "#CFE6CA", "#CFE6CA")
        rF.margin_right = 10
        rF.margin_left = 10
        ts.legend.add_face(rF, 0)
        ts.legend.add_face(TextFace(" Nucleotide based alignment"), 1)
        ts.legend_position = 3

    # Best-effort taxonomy annotation: a failure here must not prevent the
    # tree from being rendered, so errors are deliberately ignored.
    # The rerooting that used to follow (set_outgroup(out) / swap_children())
    # was removed: `out` was never assigned in this function, so those calls
    # always failed silently and never modified the tree.
    try:
        tree.set_species_naming_function(spname)
        annotate_tree_with_ncbi(tree)
    except Exception:
        pass

    tree.render(outfile, tree_style=ts, w=170, units='mm', dpi=150)
    tree.render(outfile + '.svg', tree_style=ts, w=170, units='mm', dpi=150)
    tree.render(outfile + '.pdf', tree_style=ts, w=170, units='mm', dpi=150)
    tree.render(outfile + '.pdf', tree_style=ts, w=170, units='mm', dpi=150)
예제 #17
0
파일: visualize.py 프로젝트: Ward9250/ete
def draw_tree(tree, conf, outfile):
    """Render ``tree`` as an image, an SVG and a PDF.

    Three files are written: ``outfile``, ``outfile + '.svg'`` and
    ``outfile + '.pdf'`` (170 mm wide, 150 dpi). Leaves are labelled with
    their species (and gene name when the node has one), internal branches
    with their support, and leaves carrying a ``named_lineage`` get one
    coloured label per tracked clade. When any node has a ``sequence``, a
    block view of the aligned sequence is added next to each leaf.

    :param tree: tree object providing traverse(), render() and the node
        attributes read by the layout functions below.
    :param conf: unused; kept for interface compatibility.
    :param outfile: path of the first rendered file and prefix of the others.
    :return: None. If ete3 cannot be imported the error is printed and
        nothing is rendered.
    """
    try:
        from ete3 import (add_face_to_node, TextFace, TreeStyle, RectFace,
                          random_color, SeqMotifFace)
    except ImportError as e:
        print(e)
        return

    def ly_basic(node):
        # Hide node symbols, except the NPR markers on internal nodes.
        node.img_style['size'] = 0
        if not node.is_leaf():
            node.img_style['shape'] = 'square'
            if (len(MIXED_RES) > 1 and hasattr(node, "tree_seqtype")
                    and node.tree_seqtype == "nt"):
                node.img_style["bgcolor"] = "#CFE6CA"
                ntF = TextFace("nt", fsize=6, fgcolor='#444', ftype='Helvetica')
                add_face_to_node(ntF, node, 10, position="branch-bottom")
            if len(NPR_TREES) > 1 and hasattr(node, "tree_type"):
                node.img_style['size'] = 4
                node.img_style['fgcolor'] = "steelblue"

        node.img_style['hz_line_width'] = 1
        node.img_style['vt_line_width'] = 1

    def ly_leaf_names(node):
        # Species name in italics, followed by the gene name when present.
        if node.is_leaf():
            spF = TextFace(node.species, fsize=10, fgcolor='#444444', fstyle='italic', ftype='Helvetica')
            add_face_to_node(spF, node, column=0, position='branch-right')
            if hasattr(node, 'genename'):
                geneF = TextFace(" (%s)" % node.genename, fsize=8, fgcolor='#777777', ftype='Helvetica')
                add_face_to_node(geneF, node, column=1, position='branch-right')

    def ly_supports(node):
        # Support value on top of every internal branch except the root.
        if not node.is_leaf() and node.up:
            supFace = TextFace("%0.2g" % (node.support), fsize=7, fgcolor='indianred')
            add_face_to_node(supFace, node, column=0, position='branch-top')

    def ly_tax_labels(node):
        # One coloured label per tracked clade found in the leaf lineage.
        if node.is_leaf():
            c = LABEL_START_COL
            for tname in TRACKED_CLADES:
                if hasattr(node, "named_lineage") and tname in node.named_lineage:
                    linF = TextFace(tname, fsize=10, fgcolor='white')
                    linF.margin_left = 3
                    linF.margin_right = 2
                    linF.background.color = lin2color[tname]
                    add_face_to_node(linF, node, c, position='aligned')
                    c += 1

            # NOTE(review): c starts at LABEL_START_COL (10), so at most
            # len(TRACKED_CLADES) - 10 empty filler faces are added here;
            # confirm this is the intended padding.
            for _ in range(c, len(TRACKED_CLADES)):
                add_face_to_node(TextFace('', fsize=10, fgcolor='slategrey'), node, c, position='aligned')
                c += 1

    def ly_block_alg(node):
        # Draw every gap-free stretch of the leaf sequence as one block.
        if node.is_leaf() and 'sequence' in node.features:
            motifs = []
            last_lt = None  # start index of the stretch being scanned
            for c, lt in enumerate(node.sequence):
                if lt != '-':
                    if last_lt is None:
                        last_lt = c
                    if c + 1 == len(node.sequence):
                        # the stretch runs up to the end of the sequence
                        motifs.append([last_lt, c, "()", 0, 12, "slategrey", "slategrey", None])
                        last_lt = None
                elif last_lt is not None:
                    # a gap closes the current stretch
                    motifs.append([last_lt, c - 1, "()", 0, 12, "grey", "slategrey", None])
                    last_lt = None

            seqFace = SeqMotifFace(node.sequence, motifs,
                                   intermotif_format="line",
                                   seqtail_format="line", scale_factor=ALG_SCALE)
            add_face_to_node(seqFace, node, ALG_START_COL, aligned=True)

    TRACKED_CLADES = ["Eukaryota", "Viridiplantae",  "Fungi",
                      "Alveolata", "Metazoa", "Stramenopiles", "Rhodophyta",
                      "Amoebozoa", "Crypthophyta", "Bacteria",
                      "Alphaproteobacteria", "Betaproteobacteria", "Cyanobacteria",
                      "Gammaproteobacteria",]

    # ["Opisthokonta",  "Apicomplexa"]

    colors = random_color(num=len(TRACKED_CLADES), s=0.45)
    lin2color = dict(zip(TRACKED_CLADES, colors))

    LABEL_START_COL = 10
    ALG_START_COL = 40
    ts = TreeStyle()
    ts.draw_aligned_faces_as_table = False
    ts.draw_guiding_lines = False
    ts.show_leaf_name = False
    ts.show_branch_support = False
    ts.scale = 160

    ts.layout_fn = [ly_basic, ly_leaf_names, ly_supports, ly_tax_labels]

    # Collect what the layout functions and the legend depend on.
    MIXED_RES = set()
    MAX_SEQ_LEN = 0
    NPR_TREES = []
    for n in tree.traverse():
        if hasattr(n, "tree_seqtype"):
            MIXED_RES.add(n.tree_seqtype)
        if hasattr(n, "tree_type"):
            NPR_TREES.append(n.tree_type)
        MAX_SEQ_LEN = max(len(getattr(n, "sequence", "")), MAX_SEQ_LEN)

    if MAX_SEQ_LEN:
        # ly_block_alg is only registered when ALG_SCALE is defined.
        ALG_SCALE = min(1, 1000./MAX_SEQ_LEN)
        ts.layout_fn.append(ly_block_alg)

    if len(NPR_TREES) > 1:
        rF = RectFace(4, 4, "steelblue", "steelblue")
        rF.margin_right = 10
        rF.margin_left = 10
        ts.legend.add_face(rF, 0)
        ts.legend.add_face(TextFace(" NPR node"), 1)
        ts.legend_position = 3

    if len(MIXED_RES) > 1:
        rF = RectFace(20, 20, "#CFE6CA", "#CFE6CA")
        rF.margin_right = 10
        rF.margin_left = 10
        ts.legend.add_face(rF, 0)
        ts.legend.add_face(TextFace(" Nucleotide based alignment"), 1)
        ts.legend_position = 3

    # Best-effort taxonomy annotation: a failure here must not prevent the
    # tree from being rendered, so errors are deliberately ignored.
    # The rerooting that used to follow (set_outgroup(out) / swap_children())
    # was removed: `out` was never assigned in this function, so those calls
    # always failed silently and never modified the tree.
    try:
        tree.set_species_naming_function(spname)
        annotate_tree_with_ncbi(tree)
    except Exception:
        pass

    tree.render(outfile, tree_style=ts, w=170, units='mm', dpi=150)
    tree.render(outfile+'.svg', tree_style=ts, w=170, units='mm', dpi=150)
    tree.render(outfile+'.pdf', tree_style=ts, w=170, units='mm', dpi=150)
예제 #18
0
def drawtree(sequence_dictionary):
    """Render the species tree with icons, labels and the scored alignment.

    Reads the module-level inputs EVALUATIONFILE, SPECIES_TREE and
    FASTA_ALIGNED_TRIMMED, decorates every leaf and writes the figure to
    SVG_TREEFILE.

    Aligned columns of each leaf:
        0      spacer
        1      SVG icon (IMG_PATH + value[0]) on the animal-class colour
        2      text label built from value[1], value[2] and the key
        3...   one cell per alignment position, coloured by the evaluation
               score of that position

    The outgroup leaf additionally carries a "Consensus score" label and a
    row of cells coloured by the 'cons' evaluation record.

    :param sequence_dictionary: dict keyed by leaf name; value[0] is the
        icon file name, value[1] and value[2] are the label parts and
        value[3] is the animal class. Keys that do not occur in the newick
        text are removed from the dict in place.
    :return: None
    """
    outgroup_name = 'XP_022098898.1'
    leaf_size = 15

    # Background colour of each animal class, shared by the leaf node style
    # and by the SVG icon.
    class_colors = {
        'Mammalia': "Chocolate",
        'Reptilia': "Olive",
        'Chondrichthyes': "SteelBlue",    # cartilaginous fish
        'Actinopterygii': "CornflowerBlue",  # ray-finned fish
        'Sarcopterygii': "DarkCyan",      # lobe-finned fish
        'Aves': "DarkSalmon",
        'Amphibia': "DarkSeaGreen",
        'Myxini': "LightBlue",
    }

    # Cell colour for each evaluation score character.
    color_dict = {
        '-': "White",
        '0': "#FF6666",
        '1': "#EE7777",
        '2': "#DD8888",
        '3': "#CC9999",
        '4': "#BBAAAA",
        '5': "#AABBBB",
        '6': "#99CCCC",
        '7': "#88DDDD",
        '8': "#77EEEE",
        '9': "#66FFFF"
    }

    # LOAD TCS EVALUATION FILE
    dict_of_evaluations = {}
    for seq_record in SeqIO.parse(EVALUATIONFILE, "fasta"):
        dict_of_evaluations[seq_record.id] = seq_record.seq
        print(seq_record.id + " : " + seq_record.seq)

    # LOAD NEWICK FILE TO CHECK FOR REMOVED TAXA AND REMOVE THEM FROM THE DICTIONARY
    # NOTE(review): this is a substring test on the raw newick text, so an
    # id that is a prefix of another id also counts as present.
    with open(SPECIES_TREE, 'r') as newick_file:
        newick_string = newick_file.read()
    for key in list(sequence_dictionary):
        if key in newick_string:
            print(key + " ok!")
        else:
            sequence_dictionary.pop(key, None)

    # LOAD FASTA SEQUENCES
    dict_of_sequences = {}
    for seq_record in SeqIO.parse(FASTA_ALIGNED_TRIMMED, "fasta"):
        # Only load fasta sequence if it was not removed during tree assembly
        if seq_record.id in newick_string:
            dict_of_sequences[seq_record.id] = seq_record.seq
        print(seq_record.id + " : " + seq_record.seq)

    print("Loading tree file " + SPECIES_TREE)
    t = Tree(SPECIES_TREE)
    ts = TreeStyle()
    ts.show_leaf_name = False
    # Zoom in x-axis direction
    ts.scale = 40
    # This makes all branches the same length!!!!!!!
    ts.force_topology = True
    t.set_outgroup(t & outgroup_name)
    ts.show_branch_support = False
    ts.show_branch_length = False
    ts.draw_guiding_lines = True
    ts.branch_vertical_margin = 10  # 10 pixels between adjacent branches
    print(t)

    # One leaf style per animal class. Background colour and size live in
    # the same style object because set_style() replaces the whole style of
    # a node: applying a separate "size only" style afterwards would discard
    # the class colour.
    leaf_styles = {}
    for class_name, color in class_colors.items():
        class_style = NodeStyle()
        class_style["bgcolor"] = color
        class_style["size"] = leaf_size
        leaf_styles[class_name] = class_style

    # Leaves of an unknown class only get the common size.
    general_leaf_style = NodeStyle()
    general_leaf_style["size"] = leaf_size

    # Internal nodes are drawn as dark red spheres.
    nstyle = NodeStyle()
    nstyle["shape"] = "sphere"
    nstyle["size"] = 15
    nstyle["fgcolor"] = "darkred"

    for node in t.traverse():
        if node.is_leaf():
            print("Setting leaf style for node " + node.name)
            animal_class_name = sequence_dictionary[node.name][3]
            print(animal_class_name)
            node.set_style(
                leaf_styles.get(animal_class_name, general_leaf_style))
        else:
            node.set_style(nstyle)

    # ADD SPACER, SVG IMAGE AND TEXT LABEL(S) TO EVERY LEAF
    for key, value in sequence_dictionary.items():
        leaf = t & key  # look the leaf up once instead of once per face

        leaf.add_face(TextFace("   ", fsize=16), 0, "aligned")

        svgFace = SVGFace(IMG_PATH + value[0], height=40)
        leaf.add_face(svgFace, 1, "aligned")
        svgFace.margin_right = 10
        svgFace.margin_left = 10
        svgFace.hzalign = 2
        svgFace.background.color = class_colors.get(value[3], "White")

        labels = [value[1] + " (" + value[2] + ", " + key + ")   "]
        if key == outgroup_name:
            # the extra rows line up with the consensus row added below
            labels += [" ", "Consensus score"]
        for label in labels:
            print(key, value[0], value[1], value[2], value[3])
            textFace = TextFace(label, fsize=16)
            leaf.add_face(textFace, 2, "aligned")
            textFace.margin_left = 10

    # ADD SEQUENCES AS TEXT (ADDING MORE THAN ONE SEQUENCE DOES NOT WORK)
    for key, value in dict_of_sequences.items():
        leaf = t & key
        is_outgroup = key == outgroup_name
        for position, residue in enumerate(value):
            column = position + 3
            residue_face = TextFace(residue, fsize=16)
            leaf.add_face(residue_face, column, "aligned")
            residue_face.background.color = color_dict[
                dict_of_evaluations[key][position]]

            if is_outgroup:
                # Stack three more cells under the outgroup residue: a white
                # spacer, the consensus score colour and another spacer.
                for background in ("White",
                                   color_dict[
                                       dict_of_evaluations['cons'][position]],
                                   "White"):
                    blank_face = TextFace(' ', fsize=16)
                    leaf.add_face(blank_face, column, "aligned")
                    blank_face.background.color = background

    # ROTATE SOME NODES TO GET THE DESIRED LEAF ORDER
    # Each pair names two leaves; the children of their common ancestor are
    # swapped. The order of the swaps is kept as originally written.
    swaps = [
        ("ENSEBUT00000000354.1", "NP_005420.1"),  # hagfish and human
        ("NP_991297.1", "NP_005420.1"),  # zebrafish and human
        ("XP_018419054.1", "NP_005420.1"),  # Tibet frog and human
        ("XP_015283812.1", "XP_009329004.1"),  # gekko and penguin
        ("XP_005304228.1", "XP_009329004.1"),  # turtle and penguin
        ("XP_007496150.2", "NP_005420.1"),  # Opossum and human
        ("XP_540047.2", "NP_005420.1"),  # Dog and human
        ("NP_033532.1", "NP_005420.1"),  # Mouse and human
        ("XP_526740.1", "NP_005420.1"),  # Chimp and human
    ]
    for first_leaf, second_leaf in swaps:
        t.get_common_ancestor(first_leaf, second_leaf).swap_children()

    # Figure provenance (not rendered into the figure):
    #   outgroup sequence: XP_022098898.1, starfish VEGF-C
    #   bootstrap: approximate Likelihood-Ratio Test
    #   alignment algorithm: m_coffee using clustalw2, t_coffee, muscle,
    #   mafft, pcma, probcons
    t.render(SVG_TREEFILE, tree_style=ts, units="mm", h=240)
예제 #19
0
def plot_tree_barplot(tree_file,
                      taxon2value_list_barplot,
                      header_list,
                      taxon2set2value_heatmap=False,
                      header_list2=False,
                      column_scale=True,
                      general_max=False,
                      barplot2percentage=False,
                      taxon2mlst=False):
    '''
    Decorate a tree with one or more aligned bar-plot columns.

    The tree is re-rooted on its midpoint outgroup, then every leaf gets,
    for each entry of ``header_list``, a text face showing the value and a
    stacked bar whose filled part is ``value / column maximum``. Optional
    MLST and heatmap columns are added before and after the bar plots.
    Nothing is rendered here; the decorated tree and its style are returned.

    :param tree_file: an ete3 ``Tree`` instance, or anything accepted by
        ``Tree(...)`` (it is passed to the constructor unchanged)
    :param taxon2value_list_barplot: dict ``leaf name -> list of values``,
        one value per entry of ``header_list``. Leaves missing from the dict
        are displayed as ``'na'`` with an empty bar.
    :param header_list: bar-plot column titles
    :param taxon2set2value_heatmap: optional dict
        ``column name -> {leaf name (str or int) -> value}``; when given,
        ``header_list2`` must be given too
    :param header_list2: heatmap column titles (keys of
        ``taxon2set2value_heatmap``)
    :param column_scale: if true (and ``header_list2`` is given), heatmap
        cells are coloured on a per-column OrRd scale; otherwise a single
        fixed colour is used
    :param general_max: if set, used as the bar maximum for every column
        instead of the per-column maximum
    :param barplot2percentage: list of bool to indicates if the number are
        percentages and the range should be set to 0-100
    :param taxon2mlst: optional dict ``int(leaf name) -> MLST label``; adds
        an MLST colour swatch and label in front of the bar plots

    :return: tuple ``(tree, tree_style)``
    '''

    import matplotlib.cm as cm
    from matplotlib.colors import rgb2hex
    import matplotlib as mpl

    # Colour used for positive heatmap cells when no per-column scale was
    # built (column_scale is false). The original code raised NameError here.
    fallback_heatmap_color = "orange"

    def _set_margins(face, top, right, bottom, left):
        # Set the four margins of an ete3 face in one call.
        face.margin_top = top
        face.margin_right = right
        face.margin_bottom = bottom
        face.margin_left = left

    if taxon2mlst:
        mlst_list = list(set(taxon2mlst.values()))
        mlst2color = dict(zip(mlst_list, get_spaced_colors(len(mlst_list))))
        mlst2color['-'] = 'white'

    if isinstance(tree_file, Tree):
        t1 = tree_file
    else:
        t1 = Tree(tree_file)

    # Calculate the midpoint node and set it as tree outgroup
    t1.set_outgroup(t1.get_midpoint_outgroup())

    tss = TreeStyle()
    tss.draw_guiding_lines = True
    tss.guiding_lines_color = "gray"
    tss.show_leaf_name = False

    # One colour scale per heatmap column, spanning that column's own range.
    column2scale = {}
    if column_scale and header_list2:
        for column in header_list2:
            values = taxon2set2value_heatmap[column].values()
            norm = mpl.colors.Normalize(vmin=min(values), vmax=max(values))
            column2scale[column] = cm.ScalarMappable(norm=norm, cmap=cm.OrRd)

    # One colour scale and one bar maximum per bar-plot column.
    bar_cmap = cm.YlGnBu
    values_lists = taxon2value_list_barplot.values()
    scale_list = []
    max_value_list = []
    for n, header in enumerate(header_list):
        data = [float(row[n]) for row in values_lists]
        if barplot2percentage is not False and barplot2percentage[n] is True:
            min_value, max_value = 0, 100
        else:
            min_value, max_value = min(data), max(data)
        norm = mpl.colors.Normalize(vmin=min_value, vmax=max_value)
        scale_list.append(cm.ScalarMappable(norm=norm, cmap=bar_cmap))
        if not general_max:
            max_value_list.append(float(max_value))
        else:
            max_value_list.append(general_max)

    # ---- aligned header ---------------------------------------------------
    # Each titled column occupies two aligned slots (value + bar, or
    # swatch + label for MLST): the title goes over the first slot and a
    # blank spacer over the second.
    n_value_columns = len(header_list)
    if taxon2mlst:
        display_headers = ['MLST'] + list(header_list)
    else:
        display_headers = header_list

    header_col = 0
    for header in display_headers:
        spacer = TextFace(' ')
        _set_margins(spacer, top=1, right=2, bottom=1, left=2)
        spacer.rotation = 90
        spacer.inner_background.color = "white"
        spacer.opacity = 1.
        spacer.hz_align = 2
        spacer.vt_align = 2
        tss.aligned_header.add_face(spacer, header_col + 1)

        title = TextFace('%s' % header)
        _set_margins(title, top=1, right=2, bottom=2, left=2)
        title.rotation = 270
        title.inner_background.color = "white"
        title.opacity = 1.
        title.hz_align = 2
        title.vt_align = 1
        tss.aligned_header.add_face(title, header_col)
        header_col += 2

    if header_list2:
        for col, header in enumerate(header_list2):
            title = TextFace('%s' % header)
            _set_margins(title, top=1, right=20, bottom=1, left=2)
            title.rotation = 270
            title.hz_align = 2
            title.vt_align = 2
            title.inner_background.color = "white"
            title.opacity = 1.
            tss.aligned_header.add_face(title, col + header_col)

    # ---- per-leaf faces ---------------------------------------------------
    for lf in t1.iter_leaves():

        if taxon2mlst:
            try:
                mlst = taxon2mlst[int(lf.name)]
                label = TextFace(' %s ' % mlst)
                label.inner_background.color = 'white'
                swatch = TextFace('  ')
                swatch.inner_background.color = mlst2color[mlst]
            except (KeyError, ValueError):
                # non-numeric leaf name or leaf without MLST assignment
                label = TextFace(' na ')
                label.inner_background.color = "grey"
                swatch = TextFace('    ')
                swatch.inner_background.color = "white"

            label.opacity = 1.
            _set_margins(label, top=2, right=2, bottom=2, left=0)
            _set_margins(swatch, top=2, right=0, bottom=2, left=2)

            lf.add_face(swatch, 0, position="aligned")
            lf.add_face(label, 1, position="aligned")
            col_add = 2
        else:
            col_add = 0

        try:
            val_list = taxon2value_list_barplot[lf.name]
        except KeyError:
            val_list = ['na'] * n_value_columns

        for col, value in enumerate(val_list):
            # show value itself
            text = TextFace('  %s  ' % str(value))
            _set_margins(text, top=1, right=5, bottom=1, left=10)
            text.inner_background.color = "white"
            text.opacity = 1.
            lf.add_face(text, col_add, position="aligned")

            # Convert once so that colour and bar length agree: numeric
            # strings get a real bar, and int/int no longer truncates under
            # Python 2.
            try:
                number = float(value)
            except (TypeError, ValueError):
                number = None

            # show bar
            if number is not None and col < len(scale_list):
                color = rgb2hex(scale_list[col].to_rgba(number))
            else:
                color = 'white'

            try:
                col_max = max_value_list[col]
                percentage = (number / col_max) * 100
                maximum_bar = ((col_max - number) / col_max) * 100
            except (TypeError, ZeroDivisionError, IndexError):
                # 'na' value, zero maximum or extra column: empty bar
                percentage = 0
                maximum_bar = 0

            bar = StackedBarFace([percentage, maximum_bar],
                                 width=100,
                                 height=10,
                                 colors=[color, "white"])
            bar.rotation = 0
            bar.inner_border.color = "grey"
            bar.inner_border.width = 0
            bar.margin_right = 15
            bar.margin_left = 0
            lf.add_face(bar, col_add + 1, position="aligned")
            col_add += 2

        if taxon2set2value_heatmap:
            for col, col_name in enumerate(header_list2):
                # leaf names may be used as str or int keys
                try:
                    value = taxon2set2value_heatmap[col_name][lf.name]
                except KeyError:
                    try:
                        value = taxon2set2value_heatmap[col_name][int(lf.name)]
                    except (KeyError, ValueError):
                        value = 0

                count = int(value)
                if count > 0:
                    # pad single digits so cells keep a similar width
                    if count > 9:
                        cell = TextFace(' %i ' % count)
                    else:
                        cell = TextFace(' %i   ' % count)
                    cell.fgcolor = "white"
                    scale = column2scale.get(col_name)
                    if scale is not None:
                        cell.inner_background.color = rgb2hex(
                            scale.to_rgba(float(value)))
                    else:
                        cell.inner_background.color = fallback_heatmap_color
                else:
                    cell = TextFace('  ')
                    cell.inner_background.color = "white"
                _set_margins(cell, top=1, right=1, bottom=1, left=20)
                cell.opacity = 1.
                lf.add_face(cell, col + col_add, position="aligned")

        name_face = TextFace(lf.name, fgcolor="black", fsize=12,
                             fstyle='italic')
        lf.add_face(name_face, 0)

    # Draw a black dot on nodes whose support is below 1; hide the others.
    for node in t1.traverse():
        nstyle = NodeStyle()
        if node.support < 1:
            nstyle["fgcolor"] = "black"
            nstyle["size"] = 6
        else:
            nstyle["fgcolor"] = "red"
            nstyle["size"] = 0
        node.set_style(nstyle)

    return t1, tss
예제 #20
0
def draw_tree(the_tree, colour, back_color, label, out_file, the_scale, extend,
              bootstrap, group_file, grid_options, the_table, pres_abs,
              circular):
    t = Tree(the_tree, quoted_node_names=True)
    #    t.ladderize()
    font_size = 8
    font_type = 'Heveltica'
    font_gap = 3
    font_buffer = 10
    o = t.get_midpoint_outgroup()
    t.set_outgroup(o)
    the_leaves = []
    for leaves in t.iter_leaves():
        the_leaves.append(leaves)
    groups = {}
    num = 0
    # set cutoff value for clades as 1/20th of the distance between the furthest two branches
    # assign nodes to groups
    last_node = None
    ca_list = []
    if not group_file is None:
        style = NodeStyle()
        style['size'] = 0
        style["vt_line_color"] = '#000000'
        style["hz_line_color"] = '#000000'
        style["vt_line_width"] = 1
        style["hz_line_width"] = 1
        for n in t.traverse():
            n.set_style(style)
        with open(group_file) as f:
            group_dict = {}
            for line in f:
                group_dict[line.split()[0]] = line.split()[1]
        for node in the_leaves:
            i = node.name
            for j in group_dict:
                if j in i:
                    if group_dict[j] in groups:
                        groups[group_dict[j]].append(i)
                    else:
                        groups[group_dict[j]] = [i]
        coloured_nodes = []
        for i in groups:
            the_col = i
            style = NodeStyle()
            style['size'] = 0
            style["vt_line_color"] = the_col
            style["hz_line_color"] = the_col
            style["vt_line_width"] = 2
            style["hz_line_width"] = 2
            if len(groups[i]) == 1:
                ca = t.search_nodes(name=groups[i][0])[0]
                ca.set_style(style)
                coloured_nodes.append(ca)
            else:
                ca = t.get_common_ancestor(groups[i])
                ca.set_style(style)
                coloured_nodes.append(ca)
                tocolor = []
                for j in ca.children:
                    tocolor.append(j)
                while len(tocolor) > 0:
                    x = tocolor.pop(0)
                    coloured_nodes.append(x)
                    x.set_style(style)
                    for j in x.children:
                        tocolor.append(j)
            ca_list.append((ca, the_col))
        if back_color:
            # for each common ancestor node get it's closest common ancestor neighbour and find the common ancestor of those two nodes
            # colour the common ancestor then add it to the group - continue until only the root node is left
            while len(ca_list) > 1:
                distance = float('inf')
                for i, col1 in ca_list:
                    for j, col2 in ca_list:
                        if not i is j:
                            parent = t.get_common_ancestor(i, j)
                            getit = True
                            the_dist = t.get_distance(i, j)
                            if the_dist <= distance:
                                distance = the_dist
                                the_i = i
                                the_j = j
                                the_i_col = col1
                                the_j_col = col2
                ca_list.remove((the_i, the_i_col))
                ca_list.remove((the_j, the_j_col))
                rgb1 = strtorgb(the_i_col)
                rgb2 = strtorgb(the_j_col)
                rgb3 = ((rgb1[0] + rgb2[0]) / 2, (rgb1[1] + rgb2[1]) / 2,
                        (rgb1[2] + rgb2[2]) / 2)
                new_col = colorstr(rgb3)
                new_node = t.get_common_ancestor(the_i, the_j)
                the_col = new_col
                style = NodeStyle()
                style['size'] = 0
                style["vt_line_color"] = the_col
                style["hz_line_color"] = the_col
                style["vt_line_width"] = 2
                style["hz_line_width"] = 2
                new_node.set_style(style)
                coloured_nodes.append(new_node)
                ca_list.append((new_node, new_col))
                for j in new_node.children:
                    tocolor.append(j)
                while len(tocolor) > 0:
                    x = tocolor.pop(0)
                    if not x in coloured_nodes:
                        coloured_nodes.append(x)
                        x.set_style(style)
                        for j in x.children:
                            tocolor.append(j)
    elif colour:
        distances = []
        for node1 in the_leaves:
            for node2 in the_leaves:
                if node1 != node2:
                    distances.append(t.get_distance(node1, node2))
        distances.sort()
        clade_cutoff = distances[len(distances) / 4]
        for node in the_leaves:
            i = node.name
            if not last_node is None:
                if t.get_distance(node, last_node) <= clade_cutoff:
                    groups[group_num].append(i)
                else:
                    groups[num] = [num, i]
                    group_num = num
                    num += 1
            else:
                groups[num] = [num, i]
                group_num = num
                num += 1
            last_node = node
        for i in groups:
            num = groups[i][0]
            h = num * 360 / len(groups)
            the_col = hsl_to_str(h, 0.5, 0.5)
            style = NodeStyle()
            style['size'] = 0
            style["vt_line_color"] = the_col
            style["hz_line_color"] = the_col
            style["vt_line_width"] = 2
            style["hz_line_width"] = 2
            if len(groups[i]) == 2:
                ca = t.search_nodes(name=groups[i][1])[0]
                ca.set_style(style)
            else:
                ca = t.get_common_ancestor(groups[i][1:])
                ca.set_style(style)
                tocolor = []
                for j in ca.children:
                    tocolor.append(j)
                while len(tocolor) > 0:
                    x = tocolor.pop(0)
                    x.set_style(style)
                    for j in x.children:
                        tocolor.append(j)
            ca_list.append((ca, h))
        # for each common ancestor node get it's closest common ancestor neighbour and find the common ancestor of those two nodes
        # colour the common ancestor then add it to the group - continue until only the root node is left
        while len(ca_list) > 1:
            distance = float('inf')
            got_one = False
            for i, col1 in ca_list:
                for j, col2 in ca_list:
                    if not i is j:
                        parent = t.get_common_ancestor(i, j)
                        getit = True
                        for children in parent.children:
                            if children != i and children != j:
                                getit = False
                                break
                        if getit:
                            the_dist = t.get_distance(i, j)
                            if the_dist <= distance:
                                distance = the_dist
                                the_i = i
                                the_j = j
                                the_i_col = col1
                                the_j_col = col2
                                got_one = True
            if not got_one:
                break
            ca_list.remove((the_i, the_i_col))
            ca_list.remove((the_j, the_j_col))
            new_col = (the_i_col + the_j_col) / 2
            new_node = t.get_common_ancestor(the_i, the_j)
            the_col = hsl_to_str(new_col, 0.5, 0.3)
            style = NodeStyle()
            style['size'] = 0
            style["vt_line_color"] = the_col
            style["hz_line_color"] = the_col
            style["vt_line_width"] = 2
            style["hz_line_width"] = 2
            new_node.set_style(style)
            ca_list.append((new_node, new_col))
    # if you just want a black tree
    else:
        style = NodeStyle()
        style['size'] = 0
        style["vt_line_color"] = '#000000'
        style["hz_line_color"] = '#000000'
        style["vt_line_width"] = 1
        style["hz_line_width"] = 1
        for n in t.traverse():
            n.set_style(style)
    color_list = [(240, 163, 255), (0, 117, 220), (153, 63, 0), (76, 0, 92),
                  (25, 25, 25), (0, 92, 49), (43, 206, 72), (255, 204, 153),
                  (128, 128, 128), (148, 255, 181), (143, 124, 0),
                  (157, 204, 0), (194, 0, 136), (0, 51, 128), (255, 164, 5),
                  (255, 168, 187), (66, 102, 0), (255, 0, 16), (94, 241, 242),
                  (0, 153, 143), (224, 255, 102), (116, 10, 255), (153, 0, 0),
                  (255, 255, 128), (255, 255, 0), (255, 80, 5), (0, 0, 0),
                  (50, 50, 50)]
    up_to_colour = {}
    ts = TreeStyle()
    column_list = []
    width_dict = {}
    if not grid_options is None:
        colour_dict = {}
        type_dict = {}
        min_val_dict = {}
        max_val_dict = {}
        leaf_name_dict = {}
        header_count = 0
        the_columns = {}
        if grid_options == 'auto':
            with open(the_table) as f:
                headers = f.readline().rstrip().split('\t')[1:]
                for i in headers:
                    the_columns[i] = [i]
                    type_dict[i] = 'colour'
                    colour_dict[i] = {'empty': '#FFFFFF'}
                    width_dict[i] = 20
                    up_to_colour[i] = 0
                    column_list.append(i)
        else:
            with open(grid_options) as g:
                for line in g:
                    if line.startswith('H'):
                        name, type, width = line.rstrip().split('\t')[1:]
                        if name in the_columns:
                            the_columns[name].append(name + '_' +
                                                     str(header_count))
                        else:
                            the_columns[name] = [
                                name + '_' + str(header_count)
                            ]
                        width = int(width)
                        name = name + '_' + str(header_count)
                        header_count += 1
                        colour_dict[name] = {'empty': '#FFFFFF'}
                        type_dict[name] = type
                        width_dict[name] = width
                        column_list.append(name)
                        up_to_colour[name] = 0
                        min_val_dict[name] = float('inf')
                        max_val_dict[name] = 0
                    elif line.startswith('C'):
                        c_name, c_col = line.rstrip().split('\t')[1:]
                        if not c_col.startswith('#'):
                            c_col = colorstr(map(int, c_col.split(',')))
                        colour_dict[name][c_name] = c_col
        val_dict = {}
        with open(the_table) as f:
            headers = f.readline().rstrip().split('\t')[1:]
            column_no = {}
            for num, i in enumerate(headers):
                if i in the_columns:
                    column_no[num] = i
            for line in f:
                name = line.split('\t')[0]
                leaf_name = None
                for n in t.traverse():
                    if n.is_leaf():
                        if name.split('.')[0] in n.name:
                            leaf_name = n.name
                if leaf_name is None:
                    continue
                else:
                    leaf_name_dict[leaf_name] = name
                vals = line.rstrip().split('\t')[1:]
                if name in val_dict:
                    sys.exit('Duplicate entry found in table.')
                else:
                    val_dict[name] = {}
                for num, val in enumerate(vals):
                    if num in column_no and val != '':
                        for q in the_columns[column_no[num]]:
                            column_name = q
                            if type_dict[column_name] == 'colour':
                                val_dict[name][column_name] = val
                                if not val in colour_dict[column_name]:
                                    colour_dict[column_name][val] = colorstr(
                                        color_list[up_to_colour[column_name] %
                                                   len(color_list)])
                                    up_to_colour[column_name] += 1
                            elif type_dict[column_name] == 'text':
                                val_dict[name][column_name] = val
                            elif type_dict[column_name] == 'colour_scale_date':
                                year, month, day = val.split('-')
                                year, month, day = int(year), int(month), int(
                                    day)
                                the_val = datetime.datetime(
                                    year, month, day, 0, 0,
                                    0) - datetime.datetime(
                                        1970, 1, 1, 0, 0, 0)
                                val_dict[name][
                                    column_name] = the_val.total_seconds()
                                if the_val.total_seconds(
                                ) < min_val_dict[column_name]:
                                    min_val_dict[
                                        column_name] = the_val.total_seconds()
                                if the_val.total_seconds(
                                ) > max_val_dict[column_name]:
                                    max_val_dict[
                                        column_name] = the_val.total_seconds()
                            elif type_dict[column_name] == 'colour_scale':
                                the_val = float(val)
                                val_dict[name][column_name] = the_val
                                if the_val < min_val_dict[column_name]:
                                    min_val_dict[column_name] = the_val
                                if the_val > max_val_dict[column_name]:
                                    max_val_dict[column_name] = the_val
                            else:
                                sys.exit('Unknown column type')
        if not out_file is None:
            new_desc = open(out_file + '.new_desc', 'w')
        else:
            new_desc = open('viridis.new_desc', 'w')
        ts.legend_position = 3
        leg_column = 0
        for num, i in enumerate(column_list):
            nameF = TextFace(font_gap * ' ' + i.rsplit('_', 1)[0] +
                             ' ' * font_buffer,
                             fsize=font_size,
                             ftype=font_type,
                             tight_text=True)
            nameF.rotation = -90
            ts.aligned_header.add_face(nameF, column=num + 1)
            new_desc.write('H\t' + i.rsplit('_', 1)[0] + '\t' + type_dict[i] +
                           '\t' + str(width_dict[i]) + '\n')
            x = num * 200
            if type_dict[i] == 'colour':
                ts.legend.add_face(TextFace(
                    font_gap * ' ' + i.rsplit('_', 1)[0] + ' ' * font_buffer,
                    fsize=font_size,
                    ftype=font_type,
                    tight_text=True),
                                   column=leg_column + 1)
                ts.legend.add_face(RectFace(width_dict[i], 20, '#FFFFFF',
                                            '#FFFFFF'),
                                   column=leg_column)
                for num2, j in enumerate(colour_dict[i]):
                    new_desc.write('C\t' + j + '\t' + colour_dict[i][j] + '\n')
                    ts.legend.add_face(TextFace(font_gap * ' ' + j +
                                                ' ' * font_buffer,
                                                fsize=font_size,
                                                ftype=font_type,
                                                tight_text=True),
                                       column=leg_column + 1)
                    ts.legend.add_face(RectFace(width_dict[i], 20,
                                                colour_dict[i][j],
                                                colour_dict[i][j]),
                                       column=leg_column)
                leg_column += 2
            elif type_dict[i] == 'colour_scale':
                ts.legend.add_face(TextFace(
                    font_gap * ' ' + i.rsplit('_', 1)[0] + ' ' * font_buffer,
                    fsize=font_size,
                    ftype=font_type,
                    tight_text=True),
                                   column=leg_column + 1)
                ts.legend.add_face(RectFace(width_dict[i], 20, '#FFFFFF',
                                            '#FFFFFF'),
                                   column=leg_column)
                for num2 in range(11):
                    y = num2 * 20 + 30
                    val = (max_val_dict[i] - min_val_dict[i]) * num2 / 10.0
                    h = val / (max_val_dict[i] - min_val_dict[i]) * 270
                    s = 0.5
                    l = 0.5
                    colour = hsl_to_str(h, s, l)
                    ts.legend.add_face(TextFace(font_gap * ' ' + str(val) +
                                                ' ' * font_buffer,
                                                fsize=font_size,
                                                ftype=font_type,
                                                tight_text=True),
                                       column=leg_column + 1)
                    ts.legend.add_face(RectFace(width_dict[i], 20, colour,
                                                colour),
                                       column=leg_column)
                leg_column += 2
            elif type_dict[i] == 'colour_scale_date':
                ts.legend.add_face(TextFace(
                    font_gap * ' ' + i.rsplit('_', 1)[0] + ' ' * font_buffer,
                    fsize=font_size,
                    ftype=font_type,
                    tight_text=True),
                                   column=leg_column + 1)
                ts.legend.add_face(RectFace(width_dict[i], 20, '#FFFFFF',
                                            '#FFFFFF'),
                                   column=leg_column)
                for num2 in range(11):
                    y = num2 * 20 + 30
                    val = (max_val_dict[i] - min_val_dict[i]) * num2 / 10.0
                    h = val / (max_val_dict[i] - min_val_dict[i]) * 360
                    s = 0.5
                    l = 0.5
                    colour = hsl_to_str(h, s, l)
                    days = str(int(val / 60 / 60 / 24)) + ' days'
                    ts.legend.add_face(TextFace(font_gap * ' ' + days +
                                                ' ' * font_buffer,
                                                fsize=font_size,
                                                ftype=font_type,
                                                tight_text=True),
                                       column=leg_column + 1)
                    ts.legend.add_face(RectFace(width_dict[i], 20, colour,
                                                colour),
                                       column=leg_column)
                leg_column += 2
            for n in t.traverse():
                if n.is_leaf():
                    name = leaf_name_dict[n.name]
                    if i in val_dict[name]:
                        val = val_dict[name][i]
                    else:
                        val = 'empty'
                    if type_dict[i] == 'colour':
                        n.add_face(RectFace(width_dict[i], 20,
                                            colour_dict[i][val],
                                            colour_dict[i][val]),
                                   column=num + 1,
                                   position="aligned")
                    elif type_dict[i] == 'colour_scale' or type_dict[
                            i] == 'colour_scale_date':
                        if val == 'empty':
                            colour = '#FFFFFF'
                        else:
                            h = (val - min_val_dict[i]) / (
                                max_val_dict[i] - min_val_dict[i]) * 360
                            s = 0.5
                            l = 0.5
                            colour = hsl_to_str(h, s, l)
                        n.add_face(RectFace(width_dict[i], 20, colour, colour),
                                   column=num + 1,
                                   position="aligned")
                    elif type_dict[i] == 'text':
                        n.add_face(TextFace(font_gap * ' ' + val +
                                            ' ' * font_buffer,
                                            fsize=font_size,
                                            ftype=font_type,
                                            tight_text=True),
                                   column=num + 1,
                                   position="aligned")
    if not pres_abs is None:
        starting_col = len(column_list) + 1
        subprocess.Popen('makeblastdb -out tempdb -dbtype prot -in ' +
                         pres_abs[0],
                         shell=True).wait()
        folder = pres_abs[1]
        len_dict = {}
        gene_list = []
        ts.legend.add_face(TextFace(font_gap * ' ' + 'Gene present/absent' +
                                    ' ' * font_buffer,
                                    fsize=font_size,
                                    ftype=font_type,
                                    tight_text=True),
                           column=starting_col + 1)
        ts.legend.add_face(RectFace(20, 20, '#FFFFFF', '#FFFFFF'),
                           column=starting_col)
        ts.legend.add_face(TextFace(font_gap * ' ' + 'Gene present/absent' +
                                    ' ' * font_buffer,
                                    fsize=font_size,
                                    ftype=font_type,
                                    tight_text=True),
                           column=starting_col + 1)
        ts.legend.add_face(RectFace(20, 20, "#5ba965", "#5ba965"),
                           column=starting_col)
        ts.legend.add_face(TextFace(font_gap * ' ' + 'Gene present/absent' +
                                    ' ' * font_buffer,
                                    fsize=font_size,
                                    ftype=font_type,
                                    tight_text=True),
                           column=starting_col + 1)
        ts.legend.add_face(RectFace(20, 20, "#cb5b4c", "#cb5b4c"),
                           column=starting_col)
        with open(pres_abs[0]) as f:
            for line in f:
                if line.startswith('>'):
                    name = line.split()[0][1:]
                    gene_list.append(name)
                    len_dict[name] = 0
                    nameF = TextFace(font_gap * ' ' + name + ' ' * font_buffer,
                                     fsize=font_size,
                                     ftype=font_type,
                                     tight_text=True)
                    nameF.rotation = -90
                    ts.aligned_header.add_face(nameF,
                                               column=starting_col +
                                               len(gene_list) - 1)
                else:
                    len_dict[name] += len(line.rstrip())
        min_length = 0.9
        min_ident = 90
        for n in t.iter_leaves():
            the_name = n.name
            if the_name[0] == '"' and the_name[-1] == '"':
                the_name = the_name[1:-1]
            if the_name.endswith('.ref'):
                the_name = the_name[:-4]
            if not os.path.exists(folder + '/' + the_name):
                for q in os.listdir(folder):
                    if q.startswith(the_name):
                        the_name = q
            if not os.path.exists(the_name + '.blast'):
                subprocess.Popen(
                    'blastx -query ' + folder + '/' + the_name +
                    ' -db tempdb -outfmt 6 -num_threads 24 -out ' + the_name +
                    '.blast',
                    shell=True).wait()
            gotit = set()
            with open(the_name + '.blast') as b:
                for line in b:
                    query, subject, ident, length = line.split()[:4]
                    ident = float(ident)
                    length = int(length)
                    if ident >= min_ident and length >= min_length * len_dict[
                            subject]:
                        gotit.add(subject)
            for num, i in enumerate(gene_list):
                if i in gotit:
                    colour = "#5ba965"
                else:
                    colour = "#cb5b4c"
                n.add_face(RectFace(20, 20, colour, colour),
                           column=num + starting_col,
                           position="aligned")
        # for num, i in enumerate(gene_list):
        #     x = (starting_col + num) * 200
        #     svg.writeString(i, x+50, 20, 12)
        #     y = 30
        #     svg.drawOutRect(x + 50, y, 12, 12, strtorgb('#5ba965'), strtorgb('#5ba965'), lt=0)
        #     svg.writeString('present', x + 70, y + 12, 12)
        #     y = 50
        #     svg.drawOutRect(x + 50, y, 12, 12, strtorgb('#cb5b4c'), strtorgb('#cb5b4c'), lt=0)
        #     svg.writeString('absent', x + 70, y + 12, 12)

    # Set these to False if you don't want bootstrap/distance values
    ts.show_branch_length = label
    ts.show_branch_support = bootstrap
    ts.show_leaf_name = False
    for node in t.traverse():
        if node.is_leaf():
            node.add_face(AttrFace("name",
                                   fsize=font_size,
                                   ftype=font_type,
                                   tight_text=True,
                                   fgcolor='black'),
                          column=0,
                          position="aligned")

    ts.margin_left = 20
    ts.margin_right = 100
    ts.margin_top = 20
    ts.margin_bottom = 20
    if extend:
        ts.draw_guiding_lines = True
    ts.scale = the_scale
    if not circular is None:
        ts.mode = "c"
        ts.arc_start = 0
        ts.arc_span = 360
    if out_file is None:
        t.show(tree_style=ts)
    else:
        t.render(out_file, w=210, units='mm', tree_style=ts)
예제 #21
0
                    default="pies.svg")

# Parse the command line, then load the tree to plot together with the
# per-subtree bookkeeping returned by get_phyparts_nodes.
args = parser.parse_args()
plot_tree, subtrees_dict, subtrees_topids = get_phyparts_nodes(
    args.species_tree,
    args.phyparts_root,
)

# Gather the concordance/conflict tables and derive the pie-chart data.
concord_dict, conflict_dict = get_concord_and_conflict(
    args.phyparts_root,
    subtrees_dict,
    subtrees_topids,
)
phyparts_dist, phyparts_pies = get_pie_chart_data(
    args.phyparts_root,
    args.num_genes,
    concord_dict,
    conflict_dict,
)

# Per-node look: one NodeStyle instance with "size" 0 is handed to every
# node, and each node's vertical line width is set to 0 as well.
nstyle = NodeStyle()
nstyle["size"] = 0
for n in plot_tree.traverse():
    n.set_style(nstyle)
    n.img_style["vt_line_width"] = 0

# Tree-wide look: custom pie layout, no automatic leaf names, and black
# guiding lines of type 0.
ts = TreeStyle()
ts.layout_fn = phyparts_pie_layout
ts.show_leaf_name = False
ts.scale = 30
ts.branch_vertical_margin = 10
ts.draw_guiding_lines = True
ts.guiding_lines_type = 0
ts.guiding_lines_color = "black"

# Make the tree ultrametric, then render it to the requested file.
plot_tree.convert_to_ultrametric()
my_svg = plot_tree.render(args.svg_name, w=595, tree_style=ts)
def _events_pdf_path(name):
    """Return *name* unchanged if it already ends in ".pdf", else append ".pdf"."""
    return name if name.endswith(".pdf") else name + ".pdf"


def _events_tree_style(layout_fn, mode):
    """Build the TreeStyle settings shared by the node-ID tree and the events tree.

    layout_fn: layout callback stored in ``ts.layout_fn``.
    mode: value for ``ts.mode`` ("c" or "r", as accepted by the -p option).
    """
    ts = TreeStyle()
    # Set custom layout function
    ts.layout_fn = layout_fn
    ts.mode = mode
    ts.complete_branch_lines_when_necessary = True
    ts.extra_branch_line_type = 0
    ts.extra_branch_line_color = "black"
    ts.branch_vertical_margin = 40
    ts.scale = 100
    # We will add node names manually
    ts.show_leaf_name = False
    ts.draw_guiding_lines = True
    return ts


def _read_event_stats(path):
    """Read the first part of the statistics file into {node id: [counts]}.

    Each data line is whitespace separated: an integer node id followed by
    integer counts.  Lines whose first character is '#' and lines without
    any field are skipped.  Reading stops at the header of the second part
    of the file.

    Raises ValueError if a field of a data line is not an integer.
    """
    stop_markers = (
        "# Number of events per domain.",
        "# Events per domain arrangement for last common ancestor",
    )
    node_stats = {}
    with open(path, "r") as handle:
        for line in handle:
            # Stop the loop at the second part of statistics file
            if line.startswith(stop_markers):
                break
            fields = line.split()
            if not fields or line[0] == '#':
                continue
            node_stats[int(fields[0])] = [int(value) for value in fields[1:]]
    return node_stats


def _add_event_legend(ts, event, short_legend):
    """Add the coloured circle + label legend entries for *event* to ``ts.legend``.

    With event "all" the six entries are laid out in columns 0-11; when
    *short_legend* is true the last three entries reuse columns 0-5 instead.
    Any other event gets its single entry in columns 0 and 1.
    """
    entries = (
        ("fusions", "DimGray", " Fusion     "),
        ("fissions", "DeepPink", ' Fission     '),
        ("termLosses", "YellowGreen", ' Terminal Loss     '),
        ("termEmergences", "DarkBlue", ' Terminal Emergence     '),
        ("singleDomainLosses", "Chocolate", ' Single Domain Loss     '),
        ("singleDomainEmergences", "DeepSkyBlue",
         ' Single Domain Emergence     '),
    )

    def add_entry(colour, label, column):
        ts.legend.add_face(CircleFace(10, colour), column=column)
        ts.legend.add_face(TextFace(label, fsize=16, fgcolor=colour),
                           column=column + 1)

    if event == "all":
        for position, (_, colour, label) in enumerate(entries):
            # Each entry takes two columns (circle, text).  In the short
            # layout entries 3-5 go back to the columns of entries 0-2.
            slot = position % 3 if short_legend else position
            add_entry(colour, label, 2 * slot)
        return

    for name, colour, label in entries:
        if name == event:
            add_entry(colour, label, 0)
            return


def main():
    """Command-line entry point: render a tree to PDF and exit.

    Two modes, selected by the -y option:

    * -y given: every node of the tree (-t) gets an ``ID`` feature holding
      its preorder index and the tree is rendered with ``layout_idtree`` to
      the -y path.
    * otherwise: the per-node counts from the statistics file (-s) are
      attached to the nodes in preorder as the features ``fusion``,
      ``fission``, ``termLoss``, ``termGain``, ``singLoss`` and
      ``singGain``; the module globals ``scal`` and ``eve`` are set from -c
      and -e; and the tree is rendered with ``layout_gen_events`` plus a
      legend to the -o path.

    ".pdf" is appended to the output path when it is missing.

    Exits with status 0 on success, 1 for an unknown event type and 2 (via
    argparse) when a required option for the selected mode is missing.
    Raises KeyError if the statistics file has no entry for a node index and
    IndexError if an entry holds fewer than six counts.
    """
    # NOTE(review): scal and eve are presumably read by layout_gen_events,
    # which is defined outside this function -- confirm before refactoring.
    global scal, eve

    parser = argparse.ArgumentParser()
    parser.add_argument("-t",
                        action="store",
                        dest="treef",
                        help="The tree file")
    parser.add_argument("-s",
                        action="store",
                        dest="statf",
                        help="The statistics file")
    parser.add_argument("-c",
                        action="store",
                        default=0,
                        type=float,
                        dest="scaling",
                        help="If this parameter is set "
                        "to more than 0, "
                        "the size of the pie charts "
                        "correlate with the total "
                        "number of events at a node "
                        "(and are scaled by the factor "
                        "given as a float).")
    parser.add_argument(
        "-e",
        action="store",
        type=str,
        default="all",
        dest="event",
        help=
        "If an event type is specified, just this event type is visualized. Per default all event types are shown on the tree.\n"
        "all\n"
        "fusions\n"
        "fissions\n"
        "termLosses\n"
        "termEmergences\n"
        "singleDomainLosses\n"
        "singleDomainEmergences")
    parser.add_argument(
        "-p",
        action="store",
        dest="treeshape",
        default="r",
        choices=["c", "r"],
        help="shape of the tree, circle (c) or tree format (r)")
    parser.add_argument("-o", action="store", dest="outputname")
    parser.add_argument(
        "-y",
        action="store",
        type=str,
        dest="NodeIDtreeName",
        default=None,
        help="Name for output file that shows a tree with all node IDs.")
    parser.add_argument(
        "-l",
        dest="short_legend",
        help=
        "Writes the full legend for all events in two levels for short trees",
        action="store_true")

    params = parser.parse_args()

    # Numeric code stored in the global ``eve`` for each accepted event type.
    event_options = {
        "all": 0,
        "fusions": 1,
        "fissions": 2,
        "termLosses": 3,
        "termEmergences": 4,
        "singleDomainLosses": 5,
        "singleDomainEmergences": 6
    }
    if params.event not in event_options:
        print(
            "Error: Please specify a valid event type. For a list of possible options use the --help parameter."
        )
        sys.exit(1)

    # Fail early with a usage message instead of crashing on a None path
    # after the tree and statistics have already been processed.
    if params.treef is None:
        parser.error("the tree file (-t) is required")
    if params.NodeIDtreeName is None:
        missing = [
            flag for flag, value in (("-s", params.statf),
                                     ("-o", params.outputname))
            if value is None
        ]
        if missing:
            parser.error("missing required option(s): " + ", ".join(missing))

    if params.NodeIDtreeName is not None:
        id_tree = Tree(params.treef, format=0)

        # Number the nodes in preorder, the same order in which the
        # statistics are assigned to the nodes in the other mode.
        for index, node in enumerate(id_tree.traverse('preorder')):
            node.add_features(ID=index)

        ts = _events_tree_style(layout_idtree, params.treeshape)
        id_tree.render(_events_pdf_path(params.NodeIDtreeName),
                       dpi=1200,
                       tree_style=ts)
        pl.close()

    else:

        tree = Tree(params.treef, format=0)

        # Read statistics file
        node_stat_dict = _read_event_stats(params.statf)

        # Assign rearrangement events to every node; the preorder index of a
        # node is its key in the statistics dictionary.
        feature_names = ("fusion", "fission", "termLoss", "termGain",
                         "singLoss", "singGain")
        for index, node in enumerate(tree.traverse('preorder')):
            counts = node_stat_dict[index]
            node.add_features(**{
                feature: counts[position]
                for position, feature in enumerate(feature_names)
            })

        scal = params.scaling
        eve = event_options[params.event]

        ts = _events_tree_style(layout_gen_events, params.treeshape)

        # legend creation
        _add_event_legend(ts, params.event, params.short_legend)
        ts.legend_position = 1

        tree.render(_events_pdf_path(params.outputname),
                    dpi=1200,
                    tree_style=ts)
        pl.close()

    sys.exit(0)