Exemplo n.º 1
0
def make_color_legend(filename,
                      colormap,
                      start,
                      end,
                      step,
                      width=100,
                      height=10,
                      display=False):
    """Draw a horizontal colour-bar legend for `colormap` as an SVG image.

    One rectangle is drawn for each value produced by
    ``util.frange(start, end + step, step)``.  Its stroke and fill colour is
    ``colormap.get(value)`` and it is ``step`` value-units wide, scaled so
    the whole bar spans a `width` x `height` canvas.

    filename -- output path or stream passed to ``util.open_stream``; if
                None, a temporary ".svg" file is created in the current
                directory and removed before the function returns
    colormap -- object with a ``get(value)`` method returning a colour
    start, end, step -- value range covered by the legend
    width, height -- canvas size
    display -- if True, show the image with the external ``display`` program

    The output stream is always closed, and a temporary file is always
    removed, even if drawing fails part-way.
    """
    import subprocess
    import sys
    from rasmus import util

    temp = filename is None
    if temp:
        filename = util.tempfile(".", "colormap", ".svg")

    try:
        s = svg.Svg(util.open_stream(filename, "w"))
        try:
            s.beginSvg(width, height)

            # pixels per value-unit; the bar covers [start, end + step)
            xscale = float(width) / (end + step - start)

            for i in util.frange(start, end + step, step):
                color = colormap.get(i)
                s.rect((i - start) * xscale, 0, step * xscale, height,
                       color, color)

            s.endSvg()
        finally:
            s.close()

        # display: pass the filename as its own argv entry so paths with
        # spaces or shell metacharacters are not re-parsed by a shell
        if display:
            try:
                subprocess.call(["display", filename])
            except OSError as e:
                # as with the former os.system call, a missing viewer is
                # reported but not fatal
                sys.stderr.write("could not run 'display': %s\n" % e)
    finally:
        # clean up temp files (it may not exist if opening it failed)
        if temp and os.path.exists(filename):
            os.remove(filename)
Exemplo n.º 2
0
def drawDistRuler(names,
                  dists,
                  scale=500,
                  padding=10,
                  textsize=12,
                  notchsize=2,
                  labelpadding=5,
                  distsize=9,
                  filename=sys.stdout):
    """Produce a ruler of pairwise distances.

    A horizontal line of length ``scale * max(dists)`` is drawn.  For every
    (name, dist) pair a notch is placed at ``scale * dist``.  The name is
    written vertically above the notch and the distance (three decimals)
    vertically below it.

    names    -- labels, paired positionally with `dists`
    dists    -- numeric distances; each is drawn at ``scale * dist``
    filename -- output path or stream handed to ``util.open_stream``
    """
    # vertical room reserved above the ruler for the rotated names
    label_height = textsize * max(len(name) for name in names)

    canvas = svg.Svg(util.open_stream(filename, "w"))

    ruler_len = scale * max(dists)
    canvas.beginSvg(ruler_len + 2 * padding,
                    2 * padding + label_height + 5 * distsize)

    canvas.beginTransform(("translate", padding, label_height + padding))

    # the ruler itself
    canvas.line(0, 0, ruler_len, 0)

    text_dx = textsize / 2.0
    dist_y = labelpadding + distsize * 3.5

    # one notch plus a name label and a distance label per entry
    for name, dist in zip(names, dists):
        pos = scale * dist
        canvas.text(name, pos + text_dx, -labelpadding, textsize, angle=-90)
        canvas.line(pos, notchsize, pos, -notchsize)
        canvas.text("%.3f" % dist, pos + text_dx, dist_y, distsize,
                    angle=-90)

    canvas.endTransform()
    canvas.endSvg()
Exemplo n.º 3
0
def make_color_legend(filename, colormap, start, end, step,
                      width=100, height=10):
    """Write an SVG colour bar for `colormap` to `filename`.

    One box is drawn per value of ``util.frange(start, end + step, step)``.
    Each box is coloured (stroke and fill) with ``colormap.get(value)`` and
    is `step` value-units wide.  The bar is scaled to a `width` x `height`
    canvas.
    """
    from rasmus import util

    canvas = svg.Svg(util.open_stream(filename, "w"))
    canvas.beginSvg(width, height)

    stop = end + step
    pixels_per_unit = float(width) / (stop - start)
    box_width = step * pixels_per_unit

    for value in util.frange(start, stop, step):
        fill = colormap.get(value)
        left = (value - start) * pixels_per_unit
        canvas.rect(left, 0, box_width, height, fill, fill)

    canvas.endSvg()
Exemplo n.º 4
0
def draw_tree(tree,
              brecon,
              stree,
              xscale=100,
              yscale=100,
              leaf_padding=10,
              label_size=None,
              label_offset=None,
              font_size=12,
              stree_font_size=20,
              canvas=None,
              autoclose=True,
              rmargin=10,
              lmargin=100,
              tmargin=100,
              bmargin=100,
              tree_color=(0, 0, 0),
              tree_trans_color=(0, 0, 0),
              stree_color=(.3, .7, .3),
              snode_color=(.2, .2, .7),
              loss_color=(1, 1, 1),
              loss_color_border=(.5, .5, .5),
              dup_color=(0, 0, 1),
              dup_color_border=(0, 0, 1),
              trans_color=(1, 1, 0),
              trans_color_border=(.5, .5, 0),
              gtrans_color=(1, 0, 0),
              gtrans_color_border=(.5, 0, 0),
              event_size=10,
              snames=None,
              rootlen=None,
              stree_width=.8,
              filename="tree.svg"):
    '''Draw a reconciliation of a parasite (gene) tree onto a host (species) tree as SVG.

    tree   -- parasite/gene tree; every node is a key of `brecon`
    brecon -- dict mapping each node of `tree` to a list of
              (snode, event, frequency) tuples.  snode is a node of `stree`.
              event is one of "spec", "gene", "loss", "dup", "trans" or
              "gtrans".  frequency is converted with float() and drawn
              with three decimals next to the event box.
              NOTE: brecon is modified in place (reordered by orderLoss).
    stree  -- host/species tree drawn underneath the gene tree
    xscale, yscale -- passed to treelib1.layout_tree for `stree`
    canvas -- existing svg.Svg to draw on; if None a new one is opened on
              `filename`
    autoclose -- if true, end the style and SVG document before returning
    *_color / *_color_border -- colours of branches and event boxes
    event_size  -- side length of the square drawn for each event
    snames      -- maps stree leaf names to displayed labels (default: identity)
    rootlen     -- length of the stub drawn before the stree root
                   (default: 10% of the largest stree x-coordinate)
    stree_width -- fraction of yscale over which gene lineages sharing a
                   host node are spread vertically

    Returns the canvas.

    NOTE(review): label_size and label_offset are accepted but never used
    when drawing.
    '''
    # set defaults
    # approximate character-width / font-size ratio, used to estimate label widths
    font_ratio = 8. / 11.
    if label_size is None:
        label_size = .7 * font_size

    # NOTE(review): legend_scale and minlen are assigned but never read in
    # this function (apparently left over from the plain draw_tree).
    if sum(x.dist for x in tree.nodes.values()) == 0:
        legend_scale = False
        minlen = xscale

    if snames is None:
        snames = dict((x, x) for x in stree.leaf_names())

    # layout stree: maps each stree node to an (x, y) position
    slayout = treelib1.layout_tree(stree, xscale, yscale)
    if rootlen is None:
        rootlen = .1 * max(l[0] for l in slayout.values())

    # setup slayout
    # Key None stands for the virtual parent of the stree root, placed
    # rootlen to its left.  Every position (including None) is then shifted
    # right by rootlen and up by half a row.
    x, y = slayout[stree.root]
    slayout[None] = (x - rootlen, y)
    for node, (x, y) in slayout.items():
        slayout[node] = (x + rootlen, y - .5 * yscale)

    # layout tree
    # ylists[snode]: gene nodes having a spec/gene/loss event at snode, in
    # preorder.  yorders[node]: one vertical slot index per event of node.
    ylists = defaultdict(lambda: [])
    yorders = {}
    # layout speciations and genes (y)
    for node in tree.preorder():
        # Remember the name of the first preorder node (the gene-tree root)
        # so it can be marked later.
        # NOTE(review): rebuilds the full preorder list for every node, which
        # is quadratic in tree size.
        if node == list(tree.preorder())[0]:
            rootNode = node.name
        yorders[node] = []
        for ev in brecon[node]:
            snode, event, frequency = ev
            if event == "spec" or event == "gene" or event == "loss":
                yorders[node].append(len(ylists[snode]))
                ylists[snode].append(node)

    # layout dups and transfers (y)
    # A branch event (dup/trans/gtrans) takes the mean of the slot indices of
    # those children whose last event lies on the same host node, or slot 0
    # if there is no such child.  Postorder ensures children come first.
    for node in tree.postorder():

        for ev in brecon[node]:
            snode, event, frequency = ev
            if event != "spec" and event != "gene" and event != "loss":
                # Find number of nodes on a single branch for y-coord
                v = [
                    yorders[child] for child in node.children
                    if brecon[child][-1][0] == snode
                ]
                if len(v) == 0:
                    yorders[node].append(0)
                else:
                    yorders[node].append(stats.mean(flatten(v)))

    # layout node (x)
    xorders = {
    }  # Per node, one x-order per event. NOTE(review): only read inside the loop that fills it
    branchFrac = {}  # Per node, how many dup/transfer events are chained along its host branch
    for node in tree.postorder():
        for n in range(len(brecon[node])):
            snode, event, frequency = brecon[node][n]
            if event == "spec" or event == "gene" or event == "loss":
                # Speciation, gene, and loss events happen at host vertices
                if not node in branchFrac:
                    branchFrac[node] = 0
            else:  # Transfers and duplications occur on branches
                # one more than the deepest chain among the children.
                # Assumes every child already has an entry, i.e. has at
                # least one event in brecon.
                v = [branchFrac[child] for child in node.children]
                if len(v) == 0:
                    branchFrac[node] = 1
                else:
                    branchFrac[node] = max(v) + 1

    for node in tree.preorder():
        xorders[node] = []
        for n in range(len(brecon[node])):
            snode, event, frequency = brecon[node][n]
            if event == "spec" or event == "gene" or event == "loss":
                # Speciation, gene, and loss events happen on vertices, not branches
                xorders[node].append(0)
            else:
                if node.parent and containsTransOrDup(node.parent, brecon):
                    # set branchFrac to the branch Frac of the parent, they are
                    # on the same branch
                    branchFrac[node] = branchFrac[node.parent]
                if containsLoss(node, brecon):
                    # if following a loss, first transfer/duplication event on branch
                    xorders[node].append(1)
                elif not node.parent:  # Root of tree
                    xorders[node].append(0)
                else:
                    xorders[node].append(maxList(xorders[node.parent]) + 1)

    # setup layout
    # layout[node] is a list with one (x, y) point per event of node.
    # layout[None] holds the position of the host parent of the gene root's
    # last host node, so that the root's parent lookup works.
    layout = {None: [slayout[brecon[tree.root][-1][0].parent]]}
    for node in tree.preorder():
        for n in range(len(brecon[node])):
            snode, event, frequency = brecon[node][n]
            # (nx, ny): host node; (px, py): its host parent;
            # (npx, npy): last drawn point of the gene parent
            nx, ny = slayout[snode]
            px, py = slayout[snode.parent]
            (npx, npy) = layout[node.parent][-1]
            # set spacing between nodes on the same branch
            # Shrink the step (by 5 from 50) until the stacked branch events
            # fit inside the host branch.
            # NOTE(review): never terminates if branchFrac[node] == 0 and
            # nx <= px, since 0 >= nx - px then stays true.
            frac = 50
            while branchFrac[node] * frac >= nx - px:
                frac = frac - 5

            # calc x
            if event == "trans" or event == "gtrans":
                if npx > px:  # transfer parent is farther forward in time than host parent
                    x = npx + frac
                else:
                    x = px + frac
            elif event == "dup":
                x = px + frac
            else:
                x = nx
            # calc y
            # Interpolate linearly along the host branch from (px, py) to
            # (nx, ny) at this x (ZeroDivisionError if nx == px).  Then
            # offset by the event's slot among the lineages at snode.

            deltay = ny - py
            slope = deltay / float(nx - px)

            deltax2 = x - px
            deltay2 = slope * deltax2
            offset = py + deltay2
            # NOTE: `frac` is reused here with a different meaning: the
            # slot's relative position in (0, 1).
            frac = (yorders[node][n] +
                    1) / float(max(len(ylists[snode]), 1) + 1)
            y = offset + (frac - .5) * stree_width * yscale

            if node in layout: layout[node].append((x, y))
            else:
                layout[node] = [(x, y)]

        # order brecon nodes temporally
        brecon[node] = orderLoss(node, brecon, layout)
        # order layout nodes temporally
        layout[node] = orderLayout(node, layout)

        # NOTE(review): leftover debugging output to stdout (Python 2 print
        # statements).  It fires when the last computed y lies more than 50
        # below the lowest stree node.
        if y > max(l[1] for l in slayout.values()) + 50:
            print nx, ny
            print px, py
            print offset, frac
            print ylists[snode], yorders[node]
            print brecon[node]
            print node, snode, layout[node]

    # layout label sizes
    # estimated pixel widths of the longest gene / species leaf names
    max_label_size = max(len(x.name)
                         for x in tree.leaves()) * font_ratio * font_size
    max_slabel_size = max(
        len(x.name) for x in stree.leaves()) * font_ratio * stree_font_size
    # NOTE(review): the string literal below is dead code (evaluated and
    # discarded), kept from the plain draw_tree implementation.
    '''
    if colormap == None:
        for node in tree:
            node.color = (0, 0, 0)
    else:
        colormap(tree)
    
    if stree and gene2species:
        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)
        losses = phylo.find_loss(tree, stree, recon)
    else:
        events = None
        losses = None
    
    # layout tree
    if layout is None:
        coords = treelib.layout_tree(tree, xscale, yscale, minlen, maxlen)
    else:
        coords = layout
    '''

    # canvas extent is derived from the species-tree layout plus label room
    xcoords, ycoords = zip(*slayout.values())
    maxwidth = max(xcoords) + max_label_size + max_slabel_size
    maxheight = max(ycoords) + yscale

    # initialize canvas
    if canvas is None:
        canvas = svg.Svg(util.open_stream(filename, "w"))
        width = int(rmargin + maxwidth + lmargin)
        height = int(tmargin + maxheight + bmargin)

        canvas.beginSvg(width, height)
        canvas.beginStyle("font-family: \"Sans\";")

        # only reached when the caller explicitly passes autoclose=None
        if autoclose == None:
            autoclose = True
    else:
        if autoclose == None:
            autoclose = False

    canvas.beginTransform(("translate", lmargin, tmargin))

    draw_stree(canvas,
               stree,
               slayout,
               yscale=yscale,
               stree_width=stree_width,
               stree_color=stree_color,
               snode_color=snode_color)

    # draw stree leaves
    # species labels sit right of the room reserved for gene leaf labels
    for node in stree:
        x, y = slayout[node]
        if node.is_leaf():
            canvas.text(snames[node.name],
                        x + leaf_padding + max_label_size,
                        y + stree_font_size / 2.,
                        stree_font_size,
                        fillColor=snode_color)

    # draw tree
    # One segment per event: from the gene parent's last point, or from
    # this node's previous event when the node's event list contains a loss.

    for node in tree:

        containsL = containsLoss(node, brecon)
        for n in range(len(brecon[node])):
            x, y = layout[node][n]

            if containsL == False:  # no loss event
                px, py = layout[node.parent][-1]
            else:  # loss event present
                if n == 0:  # event is loss
                    px, py = layout[node.parent][-1]
                else:  # event stems from loss
                    px, py = layout[node][n - 1]

            trans = False

            # the gene root has no parent, so no incoming segment is drawn
            if node.parent:
                snode, event, frequency = brecon[node][n]
                if n == 0:
                    psnode, pevent, pfrequency = brecon[node.parent][-1]

                # Event stemming from a loss event
                else:
                    psnode, pevent, pfrequency = brecon[node][n - 1]
                # a transfer edge is one whose predecessor is a (guilty)
                # transfer into a different host node
                if pevent == "trans" or pevent == "gtrans":
                    if psnode != snode:
                        trans = True
                else:
                    trans = False

                if not trans:
                    canvas.line(x, y, px, py, color=tree_color)

                # draw the transfer dashed line
                # Cubic Bezier whose two control points coincide, `arch`
                # px left of the segment midpoint.
                else:
                    arch = 20
                    x2 = (x * .5 + px * .5) - arch
                    y2 = (y * .5 + py * .5)
                    x3 = (x * .5 + px * .5) - arch
                    y3 = (y * .5 + py * .5)
                    # draw regular transfer dashed line
                    if pevent == "trans":
                        canvas.write(
                            "<path d='M%f %f C%f %f %f %f %f %f' %s />\n " %
                            (x, y, x2, y2, x3, y3, px, py,
                             " style='stroke-dasharray: 4, 2' " +
                             svg.colorFields(tree_trans_color, (0, 0, 0, 0))))
                    # draw guilty transfer dashed line
                    else:
                        canvas.write(
                            "<path d='M%f %f C%f %f %f %f %f %f' %s />\n " %
                            (x, y, x2, y2, x3, y3, px, py,
                             " style='stroke-dasharray: 4, 2' " +
                             svg.colorFields(gtrans_color, (0, 0, 0, 0))))

    # draw events
    # Each event gets a square of side event_size plus a label of
    # "<frequency to 3 decimals><node name>".  "gene" events get no box.
    for node in tree:
        # arrow-shaped marker and "Root Node" caption left of the gene root
        if node.name == rootNode:
            x, y = layout[node][0]
            canvas.polygon((x-20, y, x-50, y+30,x-50, y+15, x-90, y+15, x-90,\
             y-15, x-50, y-15, x-50, y-30), strokeColor = (1,.7,.3), \
             fillColor = (1,.7,.3))

            canvas.text("Root Node", x-88, y+5, font_size+2,\
                fillColor = (0,0,0))
        for n in range(len(brecon[node])):
            snode, event, frequency = brecon[node][n]
            frequency = float(frequency)
            x, y = layout[node][n]
            o = event_size / 2.0
            if event == "loss":  # draw boxes, frequencies of loss events
                # NOTE(review): label uses loss_color, which defaults to white
                canvas.rect(x - o,
                            y - o,
                            event_size,
                            event_size,
                            fillColor=loss_color,
                            strokeColor=loss_color_border)
                canvas.text("{:.3f}".format(frequency) + node.name,
                            x - o,
                            y - o,
                            font_size + 2,
                            fillColor=loss_color)

            if event == "spec":  # draw boxes, frequencies of speciation events
                canvas.rect(x - o,
                            y - o,
                            event_size,
                            event_size,
                            fillColor=(0, 0, 0),
                            strokeColor=(0, 0, 0))
                canvas.text("{:.3f}".format(frequency) + node.name,
                            x - o,
                            y - o,
                            font_size + 2,
                            fillColor=(0, 0, 0))

            if event == "dup":  # draw boxes, frequencies of duplication events
                canvas.rect(x - o,
                            y - o,
                            event_size,
                            event_size,
                            fillColor=dup_color,
                            strokeColor=dup_color_border)
                canvas.text("{:.3f}".format(frequency) + node.name,
                            x - o,
                            y - o,
                            font_size + 2,
                            fillColor=dup_color)

            elif event == "trans":  # draw boxes, frequencies of transfer events
                canvas.rect(x - o,
                            y - o,
                            event_size,
                            event_size,
                            fillColor=trans_color,
                            strokeColor=trans_color_border)
                canvas.text("{:.3f}".format(frequency) + node.name,
                            x - o,
                            y - o,
                            font_size + 2,
                            fillColor=trans_color)

            elif event == "gtrans":  # draw boxes, frequencies of guilty transfer events
                canvas.rect(x - o,
                            y - o,
                            event_size,
                            event_size,
                            fillColor=gtrans_color,
                            strokeColor=gtrans_color_border)
                canvas.text("{:.3f}".format(frequency) + node.name,
                            x - o,
                            y - o,
                            font_size + 2,
                            fillColor=gtrans_color)

    # draw tree leaves
    # gene leaf names are written next to their "gene" event positions
    for node in tree:
        for n in range(len(brecon[node])):
            x, y = layout[node][n]
            if node.is_leaf() and brecon[node][n][1] == "gene":
                canvas.text(node.name,
                            x + leaf_padding,
                            y + font_size / 2.,
                            font_size + 2,
                            fillColor=(0, 0, 0))

    canvas.endTransform()

    if autoclose:
        canvas.endStyle()
        canvas.endSvg()

    return canvas
Exemplo n.º 5
0
def draw_tree(tree,
              labels=None,
              xscale=100,
              yscale=20,
              canvas=None,
              leafPadding=10,
              leafFunc=lambda x: str(x.name),
              labelOffset=None,
              fontSize=10,
              labelSize=None,
              minlen=1,
              maxlen=util.INF,
              filename=sys.stdout,
              rmargin=150,
              lmargin=10,
              tmargin=0,
              bmargin=None,
              colormap=None,
              stree=None,
              layout=None,
              gene2species=None,
              lossColor=(0, 0, 1),
              dupColor=(1, 0, 0),
              eventSize=4,
              legendScale=False,
              autoclose=None,
              extendRoot=True,
              labelLeaves=True,
              drawHoriz=True,
              nodeSize=0):
    """Draw `tree` as an SVG image and return the canvas.

    labels      -- optional dict mapping node names to branch labels, drawn
                   above the branch; multi-line labels are split on newlines
                   (default: no labels)
    xscale, yscale, minlen, maxlen -- passed to treelib.layout_tree
    canvas      -- existing svg.Svg to draw on; if None a new one is opened
                   on `filename`
    leafFunc    -- node -> text used for leaf labels
    labelOffset -- vertical offset of branch labels (default -1)
    labelSize   -- font size of branch labels (default .7 * fontSize)
    bmargin     -- bottom margin (default yscale)
    colormap    -- callable that sets ``node.color`` on every node of the
                   tree; default paints every node black
    stree, gene2species -- if both given, the tree is reconciled against
                   the species tree and duplications/losses are drawn
    layout      -- precomputed node -> (x, y) positions; computed if None
    legendScale -- False: no scale bar; True: automatically chosen
                   power-of-ten length; a number: scale bar of that length
                   (in branch-length units)
    autoclose   -- end the SVG document before returning; None means "yes
                   if this function created the canvas"
    extendRoot  -- draw a stub for the root branch starting at x = 0
    drawHoriz   -- draw right-angled branches; forced on when labels or
                   events are drawn
    nodeSize    -- radius of a filled circle drawn at each node (0 = none)
    """
    # fresh dict per call instead of a shared mutable default
    if labels is None:
        labels = {}

    # set defaults
    fontRatio = 8. / 11.

    if labelSize is None:
        labelSize = .7 * fontSize

    if labelOffset is None:
        labelOffset = -1

    if bmargin is None:
        bmargin = yscale

    # without branch lengths a scale bar is meaningless; give every branch
    # a visible length instead
    if sum(x.dist for x in tree.nodes.values()) == 0:
        legendScale = False
        minlen = xscale

    if colormap is None:
        for node in tree:
            node.color = (0, 0, 0)
    else:
        colormap(tree)

    if stree and gene2species:
        recon = phylo.reconcile(tree, stree, gene2species)
        events = phylo.label_events(tree, recon)
        losses = phylo.find_loss(tree, stree, recon)
    else:
        events = None
        losses = None

    if len(labels) > 0 or (stree and gene2species):
        drawHoriz = True

    # layout tree
    if layout is None:
        coords = treelib.layout_tree(tree, xscale, yscale, minlen, maxlen)
    else:
        coords = layout

    xcoords, ycoords = zip(*coords.values())
    maxwidth = max(xcoords)
    maxheight = max(ycoords) + labelOffset

    # initialize canvas
    if canvas is None:
        canvas = svg.Svg(util.open_stream(filename, "w"))
        width = int(rmargin + maxwidth + lmargin)
        height = int(tmargin + maxheight + bmargin)

        canvas.beginSvg(width, height)

        if autoclose is None:
            autoclose = True
    else:
        if autoclose is None:
            autoclose = False

    # draw tree
    def walk(node):
        # Draw node's incoming branch, its labels and node marker, then
        # recurse into its children.
        x, y = coords[node]
        if node.parent:
            parentx, parenty = coords[node.parent]
        else:
            if extendRoot:
                parentx, parenty = 0, y
            else:
                parentx, parenty = x, y  # e.g. no branch

        # draw branch
        if drawHoriz:
            canvas.line(parentx, y, x, y, color=node.color)
        else:
            canvas.line(parentx, parenty, x, y, color=node.color)

        # draw branch labels, centred on the branch and stacked upward
        if node.name in labels:
            branchlen = x - parentx
            lines = str(labels[node.name]).split("\n")
            labelwidth = max(map(len, lines))
            labellen = min(labelwidth * fontRatio * fontSize,
                           max(int(branchlen - 1), 0))

            for i, line in enumerate(lines):
                canvas.text(
                    line, parentx + (branchlen - labellen) / 2.,
                    y + labelOffset + (-len(lines) + 1 + i) * (labelSize + 1),
                    labelSize)

        # draw nodes
        if nodeSize > 0:
            canvas.circle(x,
                          y,
                          nodeSize,
                          strokeColor=svg.null,
                          fillColor=node.color)

        # draw leaf labels or recur
        if node.is_leaf():
            if labelLeaves:
                canvas.text(leafFunc(node),
                            x + leafPadding,
                            y + fontSize / 2.,
                            fontSize,
                            fillColor=node.color)
        else:
            if drawHoriz:
                # draw vertical part of branch
                top = coords[node.children[0]][1]
                bot = coords[node.children[-1]][1]
                canvas.line(x, top, x, bot, color=node.color)

            # draw children
            for child in node.children:
                walk(child)

    canvas.beginTransform(("translate", lmargin, tmargin))
    walk(tree.root)

    if stree and gene2species:
        draw_events(canvas,
                    tree,
                    coords,
                    events,
                    losses,
                    lossColor=lossColor,
                    dupColor=dupColor,
                    size=eventSize)
    canvas.endTransform()

    # draw legend
    if legendScale:
        # `== True` (not `is True`) is kept for compatibility: a value equal
        # to 1 has always selected the automatic scale
        if legendScale == True:
            # automatically choose a power-of-ten length no longer than
            # the tree
            length = maxwidth / float(xscale)
            order = math.floor(math.log10(length))
            length = 10**order
        else:
            # an explicit number is the scale-bar length; previously this
            # path left `length` unbound and raised NameError
            length = legendScale

        drawScale(lmargin,
                  tmargin + maxheight + bmargin - fontSize,
                  length,
                  xscale,
                  fontSize,
                  canvas=canvas)

    if autoclose:
        canvas.endSvg()

    return canvas
Exemplo n.º 6
0
def heatmap(matrix,
            width=20,
            height=20,
            colormap=None,
            filename=None,
            rlabels=None,
            clabels=None,
            display=True,
            xdir=1,
            ydir=1,
            xmargin=0,
            ymargin=0,
            labelPadding=2,
            labelSpacing=4,
            mincutoff=None,
            maxcutoff=None,
            showVals=False,
            formatVals=str,
            valColor=black,
            clabelsAngle=270,
            clabelsPadding=None,
            rlabelsAngle=0,
            rlabelsPadding=None,
            colors=None,
            strokeColors=None,
            valAnchor="start",
            close=True):
    """Render `matrix` as an SVG heat map with one `width` x `height` cell per entry.

    matrix      -- sequence of rows indexed ``matrix[i][j]``; every row is
                   assumed to have ``len(matrix[0])`` entries
    colormap    -- object whose ``get(value)`` returns a colour; defaults to
                   ``rainbowColorMap`` over all values; ignored if `colors`
                   is given
    filename    -- output path or stream; if None, a temporary ".svg" file
                   is used and removed before returning
    rlabels, clabels -- row/column labels; lengths must match the matrix
    display     -- show the result with the external ``display`` program
                   (requires `close`)
    xdir, ydir  -- 1 or -1: direction in which columns/rows advance
    mincutoff, maxcutoff -- cells with values below/above these bounds are
                   left blank; None disables the bound
    showVals    -- also write each value (via `formatVals`) inside its cell,
                   in `valColor`, anchored per `valAnchor`
                   ("start", "middle" or "end")
    colors, strokeColors -- optional per-cell fill/stroke colours
    close       -- end and close the SVG document

    Returns the svg.Svg object.
    Raises Exception for display without close, or a bad xdir/ydir/valAnchor.
    """
    import subprocess
    import sys
    from rasmus import util

    if display and (not close):
        raise Exception("must close file if display is used")

    # determine filename
    if filename is None:
        filename = util.tempfile(".", "heatmap", ".svg")
        temp = True
    else:
        temp = False

    # determine colormap (only needed when no explicit per-cell colours)
    if colors is None:
        if colormap is None:
            colormap = rainbowColorMap(util.flatten(matrix))

    # determine matrix size and orientation
    nrows = len(matrix)
    ncols = len(matrix[0])

    if xdir == 1:
        xstart = xmargin
        ranchor = "end"
        coffset = width
    elif xdir == -1:
        xstart = xmargin + ncols * width
        ranchor = "start"
        coffset = 0
    else:
        raise Exception("xdir must be 1 or -1")

    if ydir == 1:
        ystart = ymargin
        roffset = height
        canchor = "start"
    elif ydir == -1:
        # rows are drawn upward from the bottom edge, which lies one full
        # matrix height (nrows * height, not nrows * width) below the margin
        ystart = ymargin + nrows * height
        roffset = 0
        canchor = "end"
    else:
        raise Exception("ydir must be 1 or -1")

    def is_cut(val):
        # True if val is outside [mincutoff, maxcutoff]; 0 is a real bound
        return ((mincutoff is not None and val < mincutoff) or
                (maxcutoff is not None and val > maxcutoff))

    # begin svg
    infile = util.open_stream(filename, "w")
    s = svg.Svg(infile)
    s.beginSvg(ncols * width + 2 * xmargin, nrows * height + 2 * ymargin)

    # draw matrix
    for i in xrange(nrows):
        for j in xrange(ncols):
            if is_cut(matrix[i][j]):
                continue

            if colors:
                color = colors[i][j]
            else:
                color = colormap.get(matrix[i][j])

            if strokeColors:
                strokeColor = strokeColors[i][j]
            else:
                strokeColor = color

            s.rect(xstart + xdir * j * width, ystart + ydir * i * height,
                   xdir * width, ydir * height, strokeColor, color)

    # draw values
    if showVals:
        # find the largest font size at which every shown value fits its cell
        fontwidth = 7 / 11.0

        sizes = []
        for i in xrange(nrows):
            for j in xrange(ncols):
                if is_cut(matrix[i][j]):
                    continue

                strval = formatVals(matrix[i][j])
                if len(strval) > 0:
                    sizes.append(
                        min(height, width / (float(len(strval)) * fontwidth)))

        if valAnchor == "start":
            xoffset = 0
        elif valAnchor == "middle":
            xoffset = 0.5
        elif valAnchor == "end":
            xoffset = 1
        else:
            raise Exception("anchor not supported: %s" % valAnchor)

        # if every cell is cut off or formats to "", there is nothing to
        # write (min() of an empty list would raise ValueError)
        if sizes:
            textsize = min(sizes)
            yoffset = int(ydir == -1)
            for i in xrange(nrows):
                for j in xrange(ncols):
                    if is_cut(matrix[i][j]):
                        continue

                    strval = formatVals(matrix[i][j])
                    s.text(strval,
                           xstart + xdir * (j + xoffset) * width,
                           ystart + ydir * (i + yoffset) * height +
                           height / 2.0 + textsize / 2.0,
                           textsize,
                           fillColor=valColor,
                           anchor=valAnchor)

    # draw labels
    if rlabels is not None:
        assert len(rlabels) == nrows, \
            "number of row labels does not equal number of rows"

        if rlabelsPadding is None:
            rlabelsPadding = labelPadding

        for i in xrange(nrows):
            x = xstart - xdir * rlabelsPadding
            y = ystart + roffset + ydir * i * height - labelSpacing / 2.
            s.text(rlabels[i],
                   x,
                   y,
                   height - labelSpacing,
                   anchor=ranchor,
                   angle=rlabelsAngle)

    if clabels is not None:
        assert len(clabels) == ncols, \
            "number of col labels does not equal number of cols"

        if clabelsPadding is None:
            clabelsPadding = labelPadding

        for j in xrange(ncols):
            x = xstart + coffset + xdir * j * width - labelSpacing / 2.
            y = ystart - ydir * clabelsPadding
            s.text(clabels[j],
                   x,
                   y,
                   width - labelSpacing,
                   anchor=canchor,
                   angle=clabelsAngle)

    # end svg
    if close:
        s.endSvg()
        s.close()

    # display matrix: pass the filename as its own argv entry so paths with
    # spaces or shell metacharacters are not re-parsed by a shell
    if display:
        try:
            subprocess.call(["display", filename])
        except OSError as e:
            # as with the former os.system call, a missing viewer is not fatal
            sys.stderr.write("could not run 'display': %s\n" % e)

    # clean up temp files
    if temp:
        os.remove(filename)

    return s
Exemplo n.º 7
0
def draw_tree(tree,
              brecon,
              stree,
              xscale=100,
              yscale=100,
              leaf_padding=10,
              label_size=None,
              label_offset=None,
              font_size=12,
              stree_font_size=20,
              canvas=None,
              autoclose=True,
              rmargin=10,
              lmargin=10,
              tmargin=0,
              bmargin=0,
              tree_color=(0, 0, 0),
              tree_trans_color=(0, 0, 0),
              stree_color=(.6, .3, .8),
              snode_color=(.2, .2, .7),
              loss_color=(1, 1, 1),
              loss_color_border=(.5, .5, .5),
              dup_color=(1, 0, 0),
              dup_color_border=(.5, 0, 0),
              trans_color=(0, 1, 0),
              trans_color_border=(0, .5, 0),
              event_size=10,
              snames=None,
              rootlen=None,
              stree_width=.8,
              filename="tree.svg"):
    """Draw a gene tree reconciled onto a species tree as an SVG picture.

    Every gene node may carry several reconciliation events:
    ``brecon[node]`` is a list of ``(snode, event, frequency)`` tuples, where
    ``snode`` is a species-tree node and ``event`` is compared against
    "spec", "gene", "loss", "dup" and "trans".  Spec/gene/loss events are
    stacked vertically inside their species branch; all other events are
    placed between their children that stay in the same species branch.
    Branches whose species is not a descendant of the parent's species are
    drawn as dashed curves (transfers).

    Side effect: for every gene node, ``brecon[node]`` is replaced with
    ``orderLoss(node, brecon, layout)``, i.e. the caller's dict is modified.

    tree        -- gene tree to draw
    brecon      -- dict: gene node -> list of (snode, event, frequency)
    stree       -- species tree drawn as the background
    snames      -- optional dict: species leaf name -> displayed name
                   (defaults to the identity mapping)
    canvas      -- svg.Svg to draw into; if None, a new canvas writing to
                   ``filename`` is created
    autoclose   -- if true, close the style and svg elements before returning
    label_size, label_offset -- accepted for interface compatibility; they
                   are not used when drawing
    Returns the canvas.
    """

    # set defaults
    font_ratio = 8. / 11.  # used below to estimate text width from length

    if label_size is None:
        label_size = .7 * font_size

    if snames is None:
        snames = dict((x, x) for x in stree.leaf_names())

    # events that sit at the node end of a species branch
    branch_events = ("spec", "gene", "loss")

    # layout stree
    slayout = treelib1.layout_tree(stree, xscale, yscale)
    if rootlen is None:
        rootlen = .1 * max(pos[0] for pos in slayout.values())

    # setup slayout: add a pseudo-root under key None, then shift every
    # position right by rootlen and up by half a row
    x, y = slayout[stree.root]
    slayout[None] = (x - rootlen, y)
    for node, (x, y) in list(slayout.items()):
        slayout[node] = (x + rootlen, y - .5 * yscale)

    # layout tree
    ylists = defaultdict(list)
    yorders = {}

    # layout speciations, genes and losses (y): stack them in preorder
    for node in tree.preorder():
        for snode, event, frequency in brecon[node]:
            if event in branch_events:
                yorders[node] = len(ylists[snode])
                ylists[snode].append(node)

    # layout dups and transfers (y): mean of the children whose last event
    # is in the same species branch, or 0 if there are none
    for node in tree.postorder():
        for snode, event, frequency in brecon[node]:
            if event not in branch_events:
                v = [yorders[child] for child in node.children
                     if brecon[child][-1][0] == snode]
                yorders[node] = stats.mean(v) if v else 0

    # layout node (x): branch-end events get order 0, other events one more
    # than their deepest same-branch child (1 if there is none)
    xorders = {}
    xmax = defaultdict(int)
    for node in tree.postorder():
        for snode, event, frequency in brecon[node]:
            if event in branch_events:
                xorders[node] = 0
            else:
                v = [xorders[child] for child in node.children
                     if brecon[child][-1][0] == snode]
                xorders[node] = max(v) + 1 if v else 1
            xmax[snode] = max(xmax[snode], xorders[node])

    # setup layout: layout[node] is a list of (x, y), one entry per event
    layout = {None: [slayout[brecon[tree.root][-1][0].parent]]}
    for node in tree:
        for snode, event, frequency in brecon[node]:
            nx, ny = slayout[snode]
            px, py = slayout[snode.parent]

            # calc x: step back from the species node towards its parent
            frac = xorders[node] / float(xmax[snode] + 1)
            deltax = nx - px
            x = nx - frac * deltax

            # calc y: point on the species branch at x, then offset within
            # the branch width according to the node's y-order
            deltay = ny - py
            slope = deltay / float(deltax)
            offset = py + slope * (x - px)

            frac = (yorders[node] + 1) / float(max(len(ylists[snode]), 1) + 1)
            y = offset + (frac - .5) * stree_width * yscale

            layout.setdefault(node, []).append((x, y))

        # reorder this node's events and positions with module helpers;
        # note this rebinds the caller's brecon entry
        brecon[node] = orderLoss(node, brecon, layout)
        layout[node] = orderLayout(node, layout)

    # layout label sizes
    max_label_size = max(len(x.name)
                         for x in tree.leaves()) * font_ratio * font_size
    max_slabel_size = max(
        len(x.name) for x in stree.leaves()) * font_ratio * stree_font_size

    xcoords, ycoords = zip(*slayout.values())
    maxwidth = max(xcoords) + max_label_size + max_slabel_size
    maxheight = max(ycoords) + .5 * yscale

    # initialize canvas
    if canvas is None:
        canvas = svg.Svg(util.open_stream(filename, "w"))
        width = int(rmargin + maxwidth + lmargin)
        height = int(tmargin + maxheight + bmargin)

        canvas.beginSvg(width, height)
        canvas.beginStyle("font-family: \"Sans\";")

        if autoclose is None:
            autoclose = True
    else:
        if autoclose is None:
            autoclose = False

    canvas.beginTransform(("translate", lmargin, tmargin))

    draw_stree(canvas,
               stree,
               slayout,
               yscale=yscale,
               stree_width=stree_width,
               stree_color=stree_color,
               snode_color=snode_color)

    # draw stree leaves
    for node in stree:
        x, y = slayout[node]
        if node.is_leaf():
            canvas.text(snames[node.name],
                        x + leaf_padding + max_label_size,
                        y + stree_font_size / 2.,
                        stree_font_size,
                        fillColor=snode_color)

    # draw tree
    for node in tree:
        containsL = containsLoss(node, brecon)

        # A branch is a transfer when the species of the node's last event
        # is neither the species of the parent's last event nor one of its
        # descendants.  This depends only on the node, so compute it once.
        trans = False
        if node.parent and brecon[node]:
            sn = brecon[node][-1][0]
            psn = brecon[node.parent][-1][0]
            while sn:
                if psn == sn:
                    break
                sn = sn.parent
            else:
                trans = True

        for n in range(len(brecon[node])):
            x, y = layout[node][n]
            # containsLoss's return value is compared with == (as before) so
            # non-bool returns keep their original meaning
            if containsL == False or brecon[node][n][1] == "loss":
                px, py = layout[node.parent][-1]
            else:
                # NOTE(review): for n == 0 this is layout[node][-1], i.e. the
                # node's last event position -- confirm this is intended
                px, py = layout[node][n - 1]

            if not trans:
                canvas.line(x, y, px, py, color=tree_color)
            else:
                # dashed cubic curve bowed to the left of the midpoint
                arch = 20
                x2 = (x * .5 + px * .5) - arch
                y2 = (y * .5 + py * .5)
                x3 = (x * .5 + px * .5) - arch
                y3 = (y * .5 + py * .5)

                canvas.write("<path d='M%f %f C%f %f %f %f %f %f' %s />\n " %
                             (x, y, x2, y2, x3, y3, px, py,
                              " style='stroke-dasharray: 4, 2' " +
                              svg.colorFields(tree_trans_color, (0, 0, 0, 0))))

    # draw events
    o = event_size / 2.0
    for node in tree:
        for n, (snode, event, frequency) in enumerate(brecon[node]):
            x, y = layout[node][n]
            if event == "loss":
                canvas.rect(x - o,
                            y - o,
                            event_size,
                            event_size,
                            fillColor=loss_color,
                            strokeColor=loss_color_border)
                canvas.text(frequency,
                            x - o,
                            y - o,
                            font_size,
                            fillColor=(1, 1, 1))
            elif event == "spec":
                # speciation frequency is written next to the species node
                canvas.text(frequency,
                            slayout[snode][0] - leaf_padding / 2,
                            slayout[snode][1] - font_size,
                            font_size,
                            fillColor=(0, 0, 0))
            elif event == "dup":
                canvas.rect(x - o,
                            y - o,
                            event_size,
                            event_size,
                            fillColor=dup_color,
                            strokeColor=dup_color_border)
                canvas.text(frequency,
                            x - o,
                            y - o,
                            font_size,
                            fillColor=dup_color)
            elif event == "trans":
                canvas.rect(x - o,
                            y - o,
                            event_size,
                            event_size,
                            fillColor=trans_color,
                            strokeColor=trans_color_border)
                canvas.text(frequency,
                            x - o,
                            y - o,
                            font_size,
                            fillColor=trans_color)

    # draw tree leaves (only leaves without a loss event get a name)
    for node in tree:
        draw_name = node.is_leaf() and containsLoss(node, brecon) == False
        for n in range(len(brecon[node])):
            x, y = layout[node][n]
            if draw_name:
                canvas.text(node.name,
                            x + leaf_padding,
                            y + font_size / 2.,
                            font_size,
                            fillColor=(0, 0, 0))

    canvas.endTransform()

    if autoclose:
        canvas.endStyle()
        canvas.endSvg()

    return canvas
Exemplo n.º 8
0
def draw_tree(tree,
              stree,
              extra,
              xscale=100,
              yscale=100,
              leaf_padding=10,
              label_size=None,
              label_offset=None,
              font_size=12,
              stree_font_size=20,
              canvas=None,
              autoclose=True,
              rmargin=10,
              lmargin=10,
              tmargin=0,
              bmargin=0,
              stree_color=(.4, .4, 1),
              snode_color=(.2, .2, .7),
              event_size=10,
              rootlen=None,
              stree_width=.8,
              filename=sys.stdout,
              labels=None,
              slabels=None):
    """Draw a gene tree with locus annotations inside a species tree (SVG).

    Each gene branch is colored by the locus of its parent node (the root
    branch uses the root's locus), and a square marker is drawn at every node
    whose locus differs from its parent's.

    extra   -- dict providing at least
                 "species_map": gene node -> species node (reconciliation)
                 "locus_map":   gene node -> locus id (hashable, sortable)
    labels  -- optional dict: gene node name -> text drawn at that node
    slabels -- optional species-tree labels, forwarded to draw_stree
    canvas  -- svg.Svg to draw into; if None, a new canvas writing to
               ``filename`` is created
    autoclose -- if true, close the style and svg elements before returning
    label_offset -- accepted for interface compatibility; not used
    Returns the canvas.
    """
    recon = extra["species_map"]
    loci = extra["locus_map"]

    # setup color map: one rainbow color per distinct locus
    all_loci = sorted(set(loci.values()))
    num_loci = len(all_loci)
    # NOTE(review): with a single locus this asks for low == high == 0;
    # confirm util.rainbow_color_map handles a zero-width range
    colormap = util.rainbow_color_map(low=0, high=num_loci - 1)
    locus_color = {}
    for ndx, locus in enumerate(all_loci):
        locus_color[locus] = colormap.get(ndx)

    # set defaults
    font_ratio = 8. / 11.  # used below to estimate text width from length

    if label_size is None:
        label_size = .7 * font_size

    snames = dict((x, x) for x in stree.leaf_names())

    if labels is None:
        labels = {}
    if slabels is None:
        slabels = {}

    # layout stree
    slayout = treelib.layout_tree(stree, xscale, yscale)

    if rootlen is None:
        rootlen = .1 * max(pos[0] for pos in slayout.values())

    # setup slayout: add a pseudo-root under key None, then shift every
    # position right by rootlen and up by half a row
    x, y = slayout[stree.root]
    slayout[None] = (x - rootlen, y)
    for node, (x, y) in list(slayout.items()):
        slayout[node] = (x + rootlen, y - .5 * yscale)

    # layout tree
    ylists = defaultdict(list)
    yorders = {}
    branch_events = ("spec", "gene")

    # layout speciations and genes (y): stack them in preorder
    events = phylo.label_events(tree, recon)
    for node in tree.preorder():
        snode = recon[node]
        if events[node] in branch_events:
            yorders[node] = len(ylists[snode])
            ylists[snode].append(node)

    # layout internal nodes (y): mean of the children's y-orders
    for node in tree.postorder():
        if events[node] not in branch_events:
            yorders[node] = stats.mean(
                [yorders[child] for child in node.children])

    # layout node (x): spec/gene nodes get order 0, other nodes one more
    # than their deepest child
    xorders = {}
    xmax = defaultdict(int)
    for node in tree.postorder():
        snode = recon[node]
        if events[node] in branch_events:
            xorders[node] = 0
        else:
            xorders[node] = max(xorders[child] for child in node.children) + 1
        xmax[snode] = max(xmax[snode], xorders[node])

    # setup layout
    layout = {None: slayout[None]}
    for node in tree:
        snode = recon[node]
        nx, ny = slayout[snode]
        px, py = slayout[snode.parent]

        # calc x: step back from the species node towards its parent
        frac = (xorders[node]) / float(xmax[snode] + 1)
        deltax = nx - px
        x = nx - frac * deltax

        # calc y: point on the species branch at x, then offset within the
        # branch width according to the node's y-order
        deltay = ny - py
        slope = deltay / float(deltax)
        deltax2 = x - px
        deltay2 = slope * deltax2
        offset = py + deltay2

        frac = (yorders[node] + 1) / float(max(len(ylists[snode]), 1) + 1)
        y = offset + (frac - .5) * stree_width * yscale

        layout[node] = (x, y)

    # layout label sizes
    max_label_size = max(len(x.name)
                         for x in tree.leaves()) * font_ratio * font_size
    max_slabel_size = max(
        len(x.name) for x in stree.leaves()) * font_ratio * stree_font_size

    xcoords, ycoords = zip(*slayout.values())
    maxwidth = max(xcoords) + max_label_size + max_slabel_size
    maxheight = max(ycoords) + .5 * yscale

    # initialize canvas
    if canvas is None:
        canvas = svg.Svg(util.open_stream(filename, "w"))
        width = int(rmargin + maxwidth + lmargin)
        height = int(tmargin + maxheight + bmargin)

        canvas.beginSvg(width, height)
        canvas.beginStyle("font-family: \"Sans\";")

        if autoclose is None:
            autoclose = True
    else:
        if autoclose is None:
            autoclose = False

    canvas.beginTransform(("translate", lmargin, tmargin))

    draw_stree(canvas,
               stree,
               slayout,
               yscale=yscale,
               stree_width=stree_width,
               stree_color=stree_color,
               snode_color=snode_color,
               slabels=slabels)

    # draw stree leaves
    for node in stree:
        x, y = slayout[node]
        if node.is_leaf():
            canvas.text(snames[node.name],
                        x + leaf_padding + max_label_size,
                        y + stree_font_size / 2.,
                        stree_font_size,
                        fillColor=snode_color)

    # draw tree: each branch takes the color of its parent's locus
    for node in tree:
        x, y = layout[node]
        px, py = layout[node.parent]

        if node.parent:
            color = locus_color[loci[node.parent]]
        else:
            color = locus_color[loci[tree.root]]

        canvas.line(x, y, px, py, color=color)

    # draw tree names and optional per-node labels
    for node in tree:
        x, y = layout[node]

        if node.is_leaf():
            canvas.text(node.name,
                        x + leaf_padding,
                        y + font_size / 2.,
                        font_size,
                        fillColor=(0, 0, 0))

        if node.name in labels:
            canvas.text(labels[node.name],
                        x,
                        y,
                        label_size,
                        fillColor=(0, 0, 0))

    # draw events: a square wherever the locus changes along a branch
    o = event_size / 2.0
    for node in tree:
        if not node.parent:
            continue

        locus = loci[node]
        if locus != loci[node.parent]:
            color = locus_color[locus]
            x, y = layout[node]
            canvas.rect(x - o,
                        y - o,
                        event_size,
                        event_size,
                        fillColor=color,
                        strokeColor=color)

    canvas.endTransform()

    if autoclose:
        canvas.endStyle()
        canvas.endSvg()

    return canvas