Example #1
0
 def __init__(self, length=1, extent = 0.03, label="", loc=2, ax=None,
              pad=0.4, borderpad=0.5, ppad = 0, sep=2, prop=None,
              frameon=True, linekw=None, **kwargs):
     """
     Build an anchored box holding a horizontal line with a text label
     stacked underneath it.

     length : Length of the line. The line is drawn from x=0 to x=length
              using ax.get_xaxis_transform(), i.e. x is in data
              coordinates and y is in axes coordinates.

     extent : Accepted for interface compatibility; not used by this method.

     label : Text placed below the line (drawn in black, size 14, bold).

     loc : Location code forwarded to AnchoredOffsetbox.

     ax : Axes whose x-axis transform is used for the line. If None, the
          current pyplot axes (plt.gca()) is used.

     pad, borderpad, prop, frameon : Forwarded to AnchoredOffsetbox.

     ppad, sep : Padding and separation forwarded to the VPacker that
                 stacks the line and the label.

     linekw : Optional dict of keyword arguments forwarded to Line2D.
              None (the default) means no extra arguments.

     **kwargs : Any further keyword arguments for AnchoredOffsetbox.
     """
     if ax is None:
         ax = plt.gca()
     # A fresh dict per call avoids sharing one mutable default between calls.
     if linekw is None:
         linekw = {}
     trans = ax.get_xaxis_transform()
     size_bar = offbox.AuxTransformBox(trans)
     line = Line2D([0, length], [0, 0], **linekw)
     size_bar.add_artist(line)
     # NOTE(review): `minimumdescent` appears to be deprecated in newer
     # matplotlib releases -- confirm against the version in use.
     txt = offbox.TextArea(label, minimumdescent=False,
                           textprops=dict(color="black", size=14,
                                          fontweight='bold'))
     # Line first, label second: the VPacker stacks them top to bottom.
     self.vpac = offbox.VPacker(children=[size_bar, txt],
                                align="center", pad=ppad, sep=sep)
     offbox.AnchoredOffsetbox.__init__(self, loc, pad=pad,
              borderpad=borderpad, child=self.vpac, prop=prop, frameon=frameon,
              **kwargs)
Example #2
0
def plot_model(model,
               save_file=None,
               ax=None,
               show=False,
               fig_title="Demographic Model",
               pop_labels=None,
               nref=None,
               draw_ancestors=True,
               draw_migrations=True,
               draw_scale=True,
               arrow_size=0.01,
               transition_size=0.05,
               gen_time=0,
               gen_time_units="Years",
               reverse_timeline=False,
               fig_bg_color='#ffffff',
               plot_bg_color='#ffffff',
               text_color='#002b36',
               gridline_color='#586e75',
               pop_color='#268bd2',
               arrow_color='#073642',
               label_size=16,
               tick_size=12,
               grid=True):
    """
    Plots a demographic model based on information contained within a _ModelInfo
    object. See the matplotlib docs for valid entries for the color parameters.

    model : A _ModelInfo object created using generate_model().

    save_file : If not None, the figure is saved to this location (at 200 dpi,
                using fig_bg_color as face and edge color). When given, the
                figure is not shown, regardless of `show`.

    ax : Matplotlib axes to draw on. If None, a new 9.6 x 5.4 inch figure with
         a single subplot is created.

    show : If True and save_file is not given, plt.show() is called once the
           model has been drawn.

    fig_title : Title of the figure.

    pop_labels : If not None, should be a list of strings of the same length as
                 the total number of final populations in the model. The string
                 at index i should be the name of the population along axis i in
                 the model's SFS. Labels are silently skipped if the length
                 does not match.

    nref : If specified, this will update the time and population size labels to
           use units based on an ancestral population size of nref. See the
           documentation for details.

    draw_ancestors : Specify whether the ancestral populations should be drawn
                     in beginning of plot. Will fade off with a gradient.

    draw_migrations : Specify whether migration arrows are drawn.

    draw_scale : Specify whether scale bar should be shown in top-left corner.

    arrow_size : Float to control the size of the migration arrows.

    transition_size : Float specifying size of the "transitional periods"
                      between populations.

    gen_time : If greater than 0, and nref given, timeline will be adjusted to
               show absolute time values, using this value as the time elapsed
               per generation.

    gen_time_units : Units used for gen_time (e.g. Years, Thousand Years, etc.).

    reverse_timeline : If True, the labels on the timeline will be reversed, so
                       that "0 time" is the present time period, rather than the
                       time of the original population.

    fig_bg_color : Background color of figure (i.e. border surrounding the
                   drawn model).

    plot_bg_color : Background color of the actual plot area.

    text_color : Color of text in the figure.

    gridline_color : Color of the plot gridlines.

    pop_color : Color of the populations.

    arrow_color : Color of the arrows showing migrations between populations.

    label_size : Font size of the axis labels and of the population labels.

    tick_size : Font size of the tick labels.

    grid : If True, horizontal gridlines are drawn at every integer value of
           the y-axis below the plot height; if False, the y-axis ticks are
           removed instead.

    A figure created by this function is closed once it has been saved or
    shown. If it was neither saved nor shown it is left open, so the caller
    can still retrieve it through pyplot. Figures belonging to a
    caller-supplied `ax` are never closed here.
    """
    # Set up the plot with a title and axis labels
    fig_kwargs = {
        'figsize': (9.6, 5.4),
        'dpi': 200,
        'facecolor': fig_bg_color,
        'edgecolor': fig_bg_color
    }
    # `ax` is rebound below, so record now whether the figure is ours; this
    # decides at the end whether the figure may be closed.
    owns_figure = ax is None
    if owns_figure:
        fig = plt.figure(**fig_kwargs)
        ax = fig.add_subplot(111)
    else:
        fig = ax.figure
    ax.set_facecolor(plot_bg_color)
    ax.set_title(fig_title, color=text_color, fontsize=24)
    xlabel = "Time Ago" if reverse_timeline else "Time"
    if nref:
        if gen_time > 0:
            xlabel += " ({})".format(gen_time_units)
        else:
            xlabel += " (Generations)"
        ylabel = "Population Sizes"
    else:
        xlabel += " (Genetic Units)"
        ylabel = "Relative Population Sizes"
    ax.set_xlabel(xlabel, color=text_color, fontsize=label_size)
    ax.set_ylabel(ylabel, color=text_color, fontsize=label_size)

    # Determine various maximum values for proper scaling within the plot
    xmax = model.tp_list[-1].time[-1]
    ymax = sum(model.tp_list[0].framesizes)
    # Reserve 10% of the time span left of zero for the ancestral populations
    ax.set_xlim([-1 * xmax * 0.1, xmax])
    ax.set_ylim([0, ymax])
    # Largest migration value over all time periods; used to scale arrows
    mig_max = 0
    for tp in model.tp_list:
        if tp.migrations is None:
            continue
        mig = np.amax(tp.migrations)
        mig_max = mig_max if mig_max > mig else mig

    # Configure axis border colors
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_color(text_color)

    # Major ticks along x-axis (time) placed at each population split
    xticks = [tp.time[0] for tp in model.tp_list]
    xticks.append(xmax)
    ax.xaxis.set_major_locator(mticker.FixedLocator(xticks))
    ax.xaxis.set_minor_locator(mticker.NullLocator())
    ax.tick_params(which='both',
                   axis='x',
                   labelcolor=text_color,
                   labelsize=tick_size,
                   top=False)
    # Choose correct time labels based on nref, gen_time, and reverse_timeline.
    # Only the label values change from here on; tick positions stay fixed.
    if reverse_timeline:
        xticks = [xmax - x for x in xticks]
    if nref:
        if gen_time > 0:
            xticks = [2 * nref * gen_time * x for x in xticks]
        else:
            xticks = [2 * nref * x for x in xticks]
        ax.set_xticklabels(['{:.0f}'.format(x) for x in xticks])
    else:
        ax.set_xticklabels(['{:.2f}'.format(x) for x in xticks])

    # Gridlines along y-axis (population size) spaced by nref size
    if grid:
        ax.yaxis.set_major_locator(mticker.FixedLocator(np.arange(ymax)))
        ax.yaxis.set_minor_locator(mticker.NullLocator())
        # The on/off flag is passed positionally because its keyword name
        # differs between matplotlib releases (`b` vs. `visible`).
        ax.grid(True, which='major', axis='y', color=gridline_color)
        ax.tick_params(which='both',
                       axis='y',
                       colors='none',
                       labelsize=tick_size)
    else:
        ax.set_yticks([])

    # Add scale in top-left corner displaying ancestral population size (Nref)
    if draw_scale:
        # Bidirectional arrow of height Nref
        arrow = mbox.AuxTransformBox(ax.transData)
        awidth = xmax * arrow_size * 0.2
        alength = ymax * arrow_size
        arrow_kwargs = {
            'width': awidth,
            'head_width': awidth * 3,
            'head_length': alength,
            'color': text_color,
            'length_includes_head': True
        }
        # Use ax.arrow (not plt.arrow) so the arrows are created on the axes
        # being drawn even when it is not pyplot's current axes.
        arrow.add_artist(
            ax.arrow(0, 0.25, 0, 0.75, zorder=100, **arrow_kwargs))
        arrow.add_artist(
            ax.arrow(0, 0.75, 0, -0.75, zorder=100, **arrow_kwargs))
        # Population bar of height Nref
        bar = mbox.AuxTransformBox(ax.transData)
        bar.add_artist(
            mpatches.Rectangle((0, 0), xmax / ymax, 1, color=pop_color))
        # Appropriate label depending on scale
        label = mbox.TextArea(str(nref) if nref else "Nref")
        label.get_children()[0].set_color(text_color)
        bars = mbox.HPacker(children=[label, arrow, bar],
                            pad=0,
                            sep=2,
                            align="center")
        scalebar = mbox.AnchoredOffsetbox(2,
                                          pad=0.25,
                                          borderpad=0.25,
                                          child=bars,
                                          frameon=False)
        ax.add_artist(scalebar)

    # Add ancestral populations using a gradient fill.
    if draw_ancestors:
        time = -1 * xmax * 0.1
        for i, ori in enumerate(model.tp_list[0].origins):
            # Draw ancestor for each initial pop
            xlist = np.linspace(time, 0.0, model.precision)
            dx = xlist[1] - xlist[0]
            low, mid, top = (ori[1], ori[1] + 1.0,
                             ori[1] + model.tp_list[0].popsizes[i][0])
            # Number of samples used for the transition into the first period
            tsize = int(transition_size * model.precision)
            y1list = np.array([low] * model.precision)
            y2list = np.array([mid] * (model.precision - tsize))
            y2list = np.append(y2list, np.linspace(mid, top, tsize))
            # Custom color map runs from bg color to pop color
            cmap = mcolors.LinearSegmentedColormap.from_list(
                "custom_map", [plot_bg_color, pop_color])
            colors = np.array(
                cmap(np.linspace(0.0, 1.0, model.precision - tsize)))
            # Gradient created by drawing multiple small rectangles
            for x, y1, y2, color in zip(xlist[:-1 * tsize],
                                        y1list[:-1 * tsize],
                                        y2list[:-1 * tsize], colors):
                rect = mpatches.Rectangle((x, y1), dx, y2 - y1, color=color)
                ax.add_patch(rect)
            # The last `tsize` samples are drawn solid, widening from the
            # unit-size ancestor to the first population size
            ax.fill_between(xlist[-1 * tsize:],
                            y1list[-1 * tsize:],
                            y2list[-1 * tsize:],
                            color=pop_color,
                            edgecolor=pop_color)

    # Iterate through time periods and populations to draw everything
    for tp_index, tp in enumerate(model.tp_list):
        # Keep track of migrations to evenly space arrows across time period
        total_migrations = np.count_nonzero(tp.migrations)
        num_migrations = 0

        for pop_index, popsize in enumerate(tp.popsizes):
            # Draw current population
            origin = tp.origins[pop_index]
            direc = tp.direcs[pop_index]
            y1 = origin[1]
            y2 = origin[1] + (direc * popsize)
            ax.fill_between(tp.time,
                            y1,
                            y2,
                            color=pop_color,
                            edgecolor=pop_color)

            # Draw connections to next populations if necessary
            if tp.descendants is not None and tp.descendants[pop_index] != -1:
                desc = tp.descendants[pop_index]
                tp_next = model.tp_list[tp_index + 1]
                # Split population case
                if isinstance(desc, tuple):
                    # Get origins
                    connect_below = tp_next.origins[desc[0]][1]
                    connect_above = tp_next.origins[desc[1]][1]
                    # Get popsizes
                    subpop_below = tp_next.popsizes[desc[0]][0]
                    subpop_above = tp_next.popsizes[desc[1]][0]
                    # Determine correct connection location
                    connect_below -= direc * subpop_below
                    connect_above += direc * subpop_above
                # Single population case
                else:
                    connect_below = tp_next.origins[desc][1]
                    subpop = tp_next.popsizes[desc][0]
                    connect_above = connect_below + direc * subpop
                # Draw the connections over the last `tsize` samples of the
                # period, filling between the current edge and a straight
                # line towards the descendant's edge
                tsize = int(transition_size * model.precision)
                cx = tp.time[-1 * tsize:]
                cy_below_1 = [origin[1]] * tsize
                cy_above_1 = origin[1] + direc * popsize[-1 * tsize:]
                cy_below_2 = np.linspace(cy_below_1[0], connect_below, tsize)
                cy_above_2 = np.linspace(cy_above_1[0], connect_above, tsize)
                ax.fill_between(cx,
                                cy_below_1,
                                cy_below_2,
                                color=pop_color,
                                edgecolor=pop_color)
                ax.fill_between(cx,
                                cy_above_1,
                                cy_above_2,
                                color=pop_color,
                                edgecolor=pop_color)

            # Draw migrations if necessary
            if draw_migrations and tp.migrations is not None:
                # Iterate through migrations for current population
                for mig_index, mig_val in enumerate(tp.migrations[pop_index]):
                    # If no migration, continue
                    if mig_val == 0:
                        continue
                    # Calculate proper offset for arrow within this period
                    num_migrations += 1
                    offset = int(tp.precision * num_migrations /
                                 (total_migrations + 1.0))
                    x = tp.time[offset]
                    dx = 0
                    # Determine which sides of populations are closest
                    y1 = origin[1]
                    y2 = y1 + direc * popsize[offset]
                    mig_y1 = tp.origins[mig_index][1]
                    mig_y2 = mig_y1 + (tp.direcs[mig_index] *
                                       tp.popsizes[mig_index][offset])
                    y = y1 if abs(mig_y1 - y1) < abs(mig_y1 - y2) else y2
                    if abs(mig_y1 - y) < abs(mig_y2 - y):
                        dy = mig_y1 - y
                    else:
                        dy = mig_y2 - y
                    # Scale arrow to proper size (never thinner than 10% of
                    # the widest arrow)
                    mig_scale = max(0.1, mig_val / mig_max)
                    awidth = xmax * arrow_size * mig_scale
                    alength = ymax * arrow_size
                    ax.arrow(x,
                             y,
                             dx,
                             dy,
                             width=awidth,
                             head_width=awidth * 3,
                             head_length=alength,
                             color=arrow_color,
                             length_includes_head=True)

    # Label populations if proper labels are given
    tp_last = model.tp_list[-1]
    if pop_labels and len(pop_labels) == len(tp_last.popsizes):
        # Labels live on a twin y-axis so they appear on the right-hand side
        ax2 = ax.twinx()
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())
        # Determine placement of ticks: vertical midpoint of each final pop
        yticks = [
            tp_last.origins[i][1] +
            0.5 * tp_last.direcs[i] * tp_last.popsizes[i][-1]
            for i in range(len(tp_last.popsizes))
        ]
        ax2.yaxis.set_major_locator(mticker.FixedLocator(yticks))
        ax2.set_yticklabels(pop_labels)
        ax2.tick_params(which='both',
                        color='none',
                        labelcolor=text_color,
                        labelsize=label_size,
                        left=False,
                        top=False,
                        right=False)
        for side in ('top', 'left', 'right', 'bottom'):
            ax2.spines[side].set_color(text_color)

    # Save or display the figure
    if save_file:
        # Save the figure that owns `ax` (not whichever figure pyplot
        # considers current) and pass only options savefig understands;
        # 'figsize' is a figure-creation option, not a savefig option.
        fig.savefig(save_file,
                    dpi=fig_kwargs['dpi'],
                    facecolor=fig_kwargs['facecolor'],
                    edgecolor=fig_kwargs['edgecolor'])
    elif show:
        plt.show()
    # Release figures created here once they have served their purpose.
    # An unsaved, unshown figure is left open so the caller can still use it.
    if owns_figure and (save_file or show):
        plt.close(fig)