Esempio n. 1
0
def plot_bar_graph(series,
                   series_colors,
                   series_labels=None,
                   series_color_emphasis=None,
                   series_errs=None,
                   series_err_colors=None,
                   series_padding=0.0,
                   series_use_labels=False,
                   series_style=None,
                   plot_xlabel=None,
                   plot_ylabel=None,
                   plot_yinvert=False,
                   plot_title=None,
                   category_labels=None,
                   category_ticks=True,
                   category_padding=0.25,
                   barwidth=0.35,
                   xpadding=0.,
                   stacked=False,
                   fontsize=8,
                   legend_fontsize=8,
                   legend_location='best',
                   savefile=None,
                   savefile_size=(3.4, 1.5),
                   horizontal=False,
                   show_plot=True):
    """
    Plot a bar graph
    @param series List of data for each series - each of these will be plotted in a different color
      Each series should have the same number of elements (one per category).
    @param series_colors List of colors for each series - same length as series
    @param series_labels List of labels for each series - same length as series
    @param series_color_emphasis List of booleans, one for each series, indicating whether the series
       color should be bold - if None no series is bold
    @param series_errs The error values for each series - if None no error bars are plotted
    @param series_err_colors The colors for the error bars for each series, if None black is used
    @param series_padding Extra gap inserted between neighbouring bars of the same category
       (bars of a stacked graph all share one position, so it does not move them)
    @param series_use_labels If True, place one tick per series and label it with the series label
       instead of labelling categories - only allowed when every series holds a single category
    @param series_style List of dicts, one per series, of extra keyword arguments for
       ax.bar / ax.barh; entries override the defaults built here - if None no extra styling
    @param plot_xlabel The label for the x-axis - if None no label is printed
    @param plot_ylabel The label for the y-axis - if None no label is printed
    @param plot_yinvert If True, invert the y-axis
    @param plot_title A title for the plot - if None no title is printed
    @param category_labels The labels for each particular category in the histogram
    @param category_ticks If true, also place a tick at each category
    @param category_padding Fraction of barwidth (0 - 1) - distance between categories
    @param barwidth The width of each bar
    @param xpadding The padding between the first bar and the left axis and the last bar and the right axis
    @param stacked If true, draw every series of a category at the same position. NOTE: values are
       not accumulated - the bars are overlaid on top of each other
    @param fontsize The size of font for all labels
    @param legend_fontsize The size of font for the legend labels
    @param legend_location The location of the legend, if None no legend is included
    @param savefile The path to save the plot to, if None plot is not saved
    @param savefile_size The size of the saved plot
    @param horizontal Plot bars horizontally instead of vertically
    @param show_plot If True, display the plot on the screen via a call to plt.show()
    @return fig, ax The figure and axis the plot was created in
    @raise ValueError If series is empty, a per-series list has the wrong length, or
       series_use_labels is set for multi-category data
    """

    # Validate
    if series is None or len(series) == 0:
        raise ValueError('No data series')
    num_series = len(series)

    if len(series_colors) != num_series:
        raise ValueError('You must define a color for every series')

    if series_labels is None:
        series_labels = [None for _ in series]
    if len(series_labels) != num_series:
        raise ValueError('You must define a label for every series')

    if series_errs is None:
        series_errs = [None for _ in series]
    if len(series_errs) != num_series:
        raise ValueError(
            'series_errs is not None. Must provide error value for every series.'
        )

    if series_err_colors is None:
        series_err_colors = ['black' for _ in series]
    if len(series_err_colors) != num_series:
        raise ValueError('Must provide an error bar color for every series')

    if series_color_emphasis is None:
        series_color_emphasis = [False for _ in series]
    if len(series_color_emphasis) != num_series:
        raise ValueError(
            'The emphasis list must be the same length as the series_colors list'
        )

    if series_use_labels and len(series[0]) > 1:
        raise ValueError('Only series containing one category may be labeled.')

    if series_style is None:
        series_style = [dict() for _ in series]
    if len(series_style) != num_series:
        raise ValueError('You must define a style for every series')

    fig, ax = plt.subplots()
    plot_utils.configure_fonts(fontsize=fontsize,
                               legend_fontsize=legend_fontsize)

    if plot_yinvert:
        ax.invert_yaxis()

    # Pick the setters of the category axis once, so the layout code below is
    # shared between vertical and horizontal graphs.
    if horizontal:
        set_ticks, set_ticklabels, set_lim = (ax.set_yticks,
                                              ax.set_yticklabels, ax.set_ylim)
    else:
        set_ticks, set_ticklabels, set_lim = (ax.set_xticks,
                                              ax.set_xticklabels, ax.set_xlim)

    # Layout along the category axis. A "group" is the bars of one category.
    # series_padding is part of the group width so neighbouring groups never
    # overlap and the axis limits below match the drawn bars.
    spacing = category_padding * barwidth
    num_categories = len(series[0])
    if stacked:
        group_width = barwidth
        series_step = 0.
    else:
        group_width = num_series * barwidth + (num_series - 1) * series_padding
        series_step = barwidth + series_padding
    category_width = spacing + group_width
    # Leading edge of the first bar of every category.
    index = numpy.arange(num_categories) * category_width + xpadding

    for idx in range(num_series):
        style = dict(
            label=series_labels[idx],
            color=colors.get_plot_color(series_colors[idx],
                                        emphasis=series_color_emphasis[idx]),
            ecolor=colors.get_plot_color(series_err_colors[idx]),
            linewidth=0,
            # Bars start at their position; matplotlib >= 2.0 centers them by
            # default, but the tick and limit math below assumes edges.
            align='edge',
        )
        style.update(series_style[idx])

        positions = index + idx * series_step
        # Positional arguments on purpose: the `left=` / `bottom=` keywords
        # were removed in matplotlib 3.0 (renamed `x` / `y`).
        if horizontal:
            ax.barh(positions, series[idx], barwidth,
                    xerr=series_errs[idx], **style)
        else:
            ax.bar(positions, series[idx], barwidth,
                   yerr=series_errs[idx], **style)

    # Label the plot
    if plot_ylabel is not None:
        ax.set_ylabel(plot_ylabel)
    if plot_xlabel is not None:
        ax.set_xlabel(plot_xlabel)
    if plot_title is not None:
        ax.set_title(plot_title)

    # Add category tick marks
    if series_use_labels:
        # One tick in the middle of every series' bar (single category).
        ticks = (xpadding + numpy.arange(num_series) *
                 (barwidth + series_padding) + 0.5 * barwidth)
        set_ticks(ticks)
        set_ticklabels(series_labels)
    elif category_ticks:
        # One tick in the middle of every category group.
        set_ticks(index + 0.5 * group_width)
        if category_labels is not None:
            set_ticklabels(category_labels)
    else:
        set_ticks([])

    # Category-axis limits: xpadding on both sides of the bars; the trailing
    # inter-category spacing after the last group is not shown.
    set_lim([0, 2 * xpadding + num_categories * category_width - spacing])

    # Legend
    if legend_location is not None:
        ax.legend(loc=legend_location, frameon=False)

    # Make the axis pretty
    plot_utils.simplify_axis(ax)

    # Save the file
    if savefile is not None:
        plot_utils.output(fig,
                          savefile,
                          savefile_size,
                          fontsize=fontsize,
                          legend_fontsize=legend_fontsize)

    # Show
    if show_plot:
        plt.show()

    return fig, ax
Esempio n. 2
0
def plot(series,
         series_colors,
         series_color_emphasis=None,
         series_labels=None,
         series_errs=None,
         series_err_colors=None,
         fill_error=True,
         plot_xlabel=None,
         plot_xlim=None,
         plot_ylabel=None,
         plot_ylim=None,
         plot_title=None,
         fontsize=8,
         legend_fontsize=8,
         linewidth=2,
         legend_location='best',
         savefile=None,
         savefile_size=(3.4, 1.5),
         show_plot=True,
         jitter_x=-1.,
         jitter_y=-1.,
         jitter_alpha=1.,
         marker_size=5,
         x_scale='linear',
         y_scale='linear',
         plot_markers=None,
         line_styles=None,
         mark_every=1):
    """
    Plot yvals as a function of xvals
    @param series List of (xvals, yvals) for each series - each of these will be plotted in a different color
    @param series_colors List of colors for each series - same length as series
    @param series_color_emphasis List of booleans, one for each series, indicating whether the series
       color should be bold - if None no series is bold
    @param series_labels List of labels for each series - same length as series
    @param series_errs The error values for each series - if None no error bars are plotted
    @param series_err_colors The colors for the error bars for each series, if None black is used
    @param fill_error If true, shade the area to show error bars, if not, draw actual error bars
    @param plot_xlabel The label for the x-axis - if None no label is printed
    @param plot_xlim The limits for the x-axis - if None these limits are selected automatically
    @param plot_ylabel The label for the y-axis - if None no label is printed
    @param plot_ylim The limits for the y-axis - if None these limits are selected automatically
    @param plot_title A title for the plot - if None no title is printed
    @param fontsize The size of font for all labels
    @param legend_fontsize The size of font for the legend labels
    @param linewidth The width of the lines in the plot
    @param legend_location The location of the legend, if None no legend is included
    @param savefile The path to save the plot to, if None plot is not saved
    @param savefile_size The size of the saved plot
    @param show_plot If True, display the plot on the screen via a call to plt.show()
    @param jitter_x the scale of the normal distribution for jitter along x direction, if -1., then no jitter along x
    @param jitter_y the scale of the normal distribution for jitter along y direction, if -1., then no jitter along y
    @param jitter_alpha the alpha value passed for every plotted line (jittered or not)
    @param marker_size The size of the line markers (matplotlib markersize)
    @param x_scale set to log or linear for the x axis
    @param y_scale set to log or linear for the y axis
    @param plot_markers if not None, a list of line marker symbols, one per series
    @param line_styles if not None, a list of line style symbols, one per series
    @param mark_every Forwarded to matplotlib's markevery - selects which points get a marker
    @return fig, ax The figure and axis the plot was created in
    @raise ValueError If series is empty or any per-series list has the wrong length
    """

    # Validate
    if series is None or len(series) == 0:
        raise ValueError('No data series')
    num_series = len(series)

    if len(series_colors) != num_series:
        raise ValueError('You must define a color for every series')

    if series_labels is None:
        series_labels = [None for _ in series]
    if len(series_labels) != num_series:
        raise ValueError('You must define a label for every series')

    if series_errs is None:
        series_errs = [None for _ in series]
    if len(series_errs) != num_series:
        raise ValueError(
            'series_errs is not None. Must provide error value for every series.'
        )

    if series_err_colors is None:
        series_err_colors = ['black' for _ in series]
    if len(series_err_colors) != num_series:
        raise ValueError('Must provide an error bar color for every series')

    if series_color_emphasis is None:
        series_color_emphasis = [False for _ in series]
    if len(series_color_emphasis) != num_series:
        raise ValueError(
            'The emphasis list must be the same length as the series_colors list'
        )

    if plot_markers is None:
        plot_markers = ['None' for _ in series]
    if len(plot_markers) != num_series:
        raise ValueError('The marker list must contain all series')

    if line_styles is None:
        line_styles = ['-' for _ in series]
    if len(line_styles) != num_series:
        raise ValueError('The line style list must contain all series')

    fig, ax = plt.subplots()
    plot_utils.configure_fonts(fontsize=fontsize,
                               legend_fontsize=legend_fontsize)

    for idx, data in enumerate(series):
        xvals = data[0]
        yvals = numpy.array(data[1])

        # Optional random jitter (by Shen Li), see
        # http://nbviewer.jupyter.org/gist/fonnesbeck/5850463
        # -1. is the "no jitter" sentinel; x is drawn before y.
        if jitter_x != -1.:
            xvals = numpy.random.normal(xvals, jitter_x, size=len(yvals))
        if jitter_y != -1.:
            yvals = numpy.random.normal(yvals, jitter_y, size=len(xvals))

        ax.plot(xvals,
                yvals,
                label=series_labels[idx],
                color=colors.get_plot_color(
                    series_colors[idx],
                    emphasis=series_color_emphasis[idx]),
                marker=plot_markers[idx],
                linestyle=line_styles[idx],
                lw=linewidth,
                markersize=marker_size,
                markevery=mark_every,
                alpha=jitter_alpha)

        errs = series_errs[idx]
        if errs is not None:
            shade_color = colors.get_plot_color(color=series_err_colors[idx])
            if fill_error:
                plot_utils.shaded_error(ax,
                                        xvals,
                                        yvals,
                                        numpy.array(errs),
                                        color=shade_color)
            else:
                ax.errorbar(xvals,
                            yvals,
                            yerr=errs,
                            linestyle='None',
                            ecolor=shade_color)

    # Label the plot
    if plot_ylabel is not None:
        ax.set_ylabel(plot_ylabel)
    if plot_xlabel is not None:
        ax.set_xlabel(plot_xlabel)
    if plot_title is not None:
        ax.set_title(plot_title)
    if plot_xlim is not None:
        ax.set_xlim(plot_xlim)
    if plot_ylim is not None:
        ax.set_ylim(plot_ylim)

    ax.set_yscale(y_scale)
    ax.set_xscale(x_scale)

    # Legend
    if legend_location is not None:
        ax.legend(loc=legend_location, frameon=False)

    # Make the axis pretty
    plot_utils.simplify_axis(ax)

    # Save the file
    if savefile is not None:
        plot_utils.output(fig,
                          savefile,
                          savefile_size,
                          fontsize=fontsize,
                          legend_fontsize=legend_fontsize)

    # Show
    if show_plot:
        plt.show()

    return fig, ax
Esempio n. 3
0
def plot_hexbin(series,
                series_labels=None,
                plot_xlabel=None,
                plot_xlim=None,
                plot_ylabel=None,
                plot_ylim=None,
                plot_title=None,
                fontsize=8,
                vmax=None,
                vmin=None,
                linewidths=None,
                savefile=None,
                savefile_size=(3.4, 1.5),
                show_plot=True,
                gridsize=20,
                series_C=None,
                reduce_C_function=numpy.mean,
                color_map=cm.jet,
                x_scale='linear',
                y_scale='linear',
                mincnt=None,
                series_edgecolors=None,
                marginals=False):
    """
    Draw a hexagonal-binning plot of one data set, with a colorbar labelled 'count'.
    @param series A pair (xvals, yvals): series[0] holds the x values, series[1] the y values
    @param series_labels List with one entry per element of series - only validated, not drawn
    @param plot_xlabel The label for the x-axis - if None no label is printed
    @param plot_xlim The limits for the x-axis - if None the data range of xvals is used
    @param plot_ylabel The label for the y-axis - if None no label is printed
    @param plot_ylim The limits for the y-axis - if None the data range of yvals is used
    @param plot_title A title for the plot - if None no title is printed
    @param fontsize The size of font for all labels
    @param vmin, vmax Color-scale limits passed to ax.hexbin
    @param linewidths The width of the hexagon edge lines
    @param savefile The path to save the plot to, if None plot is not saved
    @param savefile_size The size of the saved plot
    @param show_plot If True, display the plot on the screen via a call to plt.show()
    @param gridsize how many hexagons in one line
    @param series_C In hexbin, 'C' is optional--it maps values to x-y coordinates;
                    if 'C' is None (default) then the result is a pure 2D histogram
    @param reduce_C_function Function of one argument that reduces all the values in a bin to a
        single number, e.g. numpy.mean, .sum, .count_nonzero, .max - used only with series_C
    @param color_map The colormap of the hexagons
    @param x_scale 'linear' or 'log' binning of the x axis (hexbin's xscale)
    @param y_scale 'linear' or 'log' binning of the y axis (hexbin's yscale)
    @param mincnt Forwarded to hexbin's mincnt (minimum number of points for a cell to be shown)
       - if None every cell is shown
    @param series_edgecolors Color of the hexagon edges - if None hexbin's default is used
    @param marginals Forwarded to hexbin's marginals
    @return The collection returned by ax.hexbin (not fig, ax)
    @raise ValueError If series is empty or series_labels has the wrong length
    """

    # Validate
    if series is None or len(series) == 0:
        raise ValueError('No data series')
    num_series = len(series)

    if series_labels is None:
        series_labels = [None for _ in series]
    if len(series_labels) != num_series:
        raise ValueError('You must define a label for every series')

    fig, ax = plt.subplots()
    plot_utils.configure_fonts(fontsize=fontsize)

    xvals = series[0]
    yvals = numpy.array(series[1])

    # `is None` instead of `== None`: for numpy arrays `== None` compares
    # element-wise and its truth value is ambiguous.
    edgecolors = (None if series_edgecolors is None
                  else colors.get_plot_color(series_edgecolors))

    # Only fill in the limits the caller did not provide.
    if plot_xlim is None:
        plot_xlim = [min(xvals), max(xvals)]
    if plot_ylim is None:
        plot_ylim = [min(yvals), max(yvals)]

    # http://stackoverflow.com/questions/2369492/generate-a-heatmap-in-matplotlib-using-a-scatter-data-set/2371812#2371812
    # http://matplotlib.org/examples/color/colormaps_reference.html
    hexbin_kwargs = dict(cmap=color_map,
                         bins=None,
                         gridsize=gridsize,
                         xscale=x_scale,
                         yscale=y_scale,
                         norm=None,
                         vmin=vmin,
                         vmax=vmax,
                         alpha=None,
                         linewidths=linewidths,
                         edgecolors=edgecolors,
                         mincnt=mincnt,
                         marginals=marginals)
    if series_C is None:
        # Pure 2D histogram, binned over the plot limits. hexbin expects the
        # extent of a log-scaled axis as powers of 10.
        x_ext = numpy.log10(plot_xlim) if x_scale == 'log' else plot_xlim
        y_ext = numpy.log10(plot_ylim) if y_scale == 'log' else plot_ylim
        im = ax.hexbin(xvals, yvals,
                       C=None,
                       extent=[x_ext[0], x_ext[1], y_ext[0], y_ext[1]],
                       **hexbin_kwargs)
    else:
        im = ax.hexbin(xvals, yvals,
                       C=series_C,
                       reduce_C_function=reduce_C_function,
                       **hexbin_kwargs)

    cb = fig.colorbar(im, ax=ax)
    cb.set_label('count')

    ax.axis([plot_xlim[0], plot_xlim[1], plot_ylim[0], plot_ylim[1]])

    if plot_ylabel is not None:
        ax.set_ylabel(plot_ylabel)
    if plot_xlabel is not None:
        ax.set_xlabel(plot_xlabel)
    if plot_title is not None:
        ax.set_title(plot_title)

    # Make the axis pretty
    plot_utils.simplify_axis(ax)

    # Save the file
    if savefile is not None:
        plot_utils.output(fig, savefile, savefile_size,
                          fontsize=fontsize)
    # Show
    if show_plot:
        plt.show()

    return im
Esempio n. 4
0
    plot_labels = [f[1] for f in features]
    from ss_plotting.make_plots import plot_bar_graph
    fig, ax = plot_bar_graph(series,
                             series_labels=series_labels,
                             series_errs=series_errs,
                             series_colors=['grey'],
                             category_labels=[f[1] for f in features],
                             category_rotation=45,
                             xpadding=0.3,
                             show_plot=False,
                             simplify=False)
    ax.set_ylim([0.70, 0.85])
    xlim = ax.get_xlim()
    ax.set_xlim([xlim[0], xlim[1] - 0.5])
    from ss_plotting import plot_utils
    plot_utils.simplify_axis(ax)
    plt.savefig("Feature_comparison.eps", format='eps', dpi=1000)
    import IPython
    IPython.embed()

    # ALl,  3F + RoC, 3F/P + RoC, 3F/T + ROC, Z-force + RoC

# def tapo_plot_confusion_matrix(cm, classes, modality, path, normalize=False, title='Confusion matrix', cmap=pp.cm.Blues):
#     """ This function prints and plots the confusion matrix.
#         Normalization can be applied by setting `normalize=True`.
#     """
#     if normalize:
#         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
#         print("Normalized confusion matrix")
#     else:
#         cm = cm.astype('int')
Esempio n. 5
0
def plot_stacked_vertical_bar_graph(series,
                                    series_colors,
                                    series_labels=None,
                                    series_color_emphasis=None,
                                    series_errs=None,
                                    series_err_colors=None,
                                    series_padding=0.0,
                                    series_use_labels=False,
                                    series_style=None,
                                    plot_xlabel=None,
                                    plot_ylabel=None,
                                    plot_yinvert=False,
                                    plot_title=None,
                                    category_labels=None,
                                    category_ticks=True,
                                    category_padding=0.25,
                                    barwidth=0.35,
                                    xpadding=0.,
                                    fontsize=8,
                                    legend_fontsize=8,
                                    legend_location='best',
                                    legend_title=None,
                                    legend_reverse_label=False,
                                    legend_bbox_to_anchor=None,
                                    legend_pos_rel_plot=None,
                                    legend_title_pos=None,
                                    legend_labelspacing=None,
                                    savefile=None,
                                    savefile_size=(3.4, 1.5),
                                    show_plot=True):
    """
    Plot a vertical bar graph whose series are stacked on top of each other in every category
    @param series List of data for each series - each of these will be plotted in a different color
      Each series should have the same number of elements (one per category); each series is drawn
      on top of the running sum of the previous ones
    @param series_colors List of colors for each series - same length as series
    @param series_labels List of labels for each series - same length as series
    @param series_color_emphasis List of booleans, one for each series, indicating whether the series
       color should be bold - if None no series is bold
    @param series_errs The error values for each series - if None no error bars are plotted
    @param series_err_colors The colors for the error bars for each series, if None black is used
    @param series_padding Gap used between series ticks when series_use_labels is set; it is also
       added to the x-axis limit
    @param series_use_labels If True, label ticks with the series labels instead of the categories -
       only allowed when every series holds a single category
    @param series_style List of dicts, one per series, of extra keyword arguments for ax.bar;
       entries override the defaults built here - if None no extra styling
    @param plot_xlabel The label for the x-axis - if None no label is printed
    @param plot_ylabel The label for the y-axis - if None no label is printed
    @param plot_yinvert If True, invert the y-axis
    @param plot_title A title for the plot - if None no title is printed
    @param category_labels The labels for each particular category in the histogram
    @param category_ticks If true, also place a tick at each category
    @param category_padding Fraction of barwidth (0 - 1) - distance between categories
    @param barwidth The width of each bar
    @param xpadding The padding between the first bar and the left axis and the last bar and the right axis
    @param fontsize The size of font for all labels
    @param legend_fontsize The size of font for the legend labels
    @param legend_location The location of the legend, if None no legend is included
    @param legend_title Title shown above the legend entries - if None no title
    @param legend_reverse_label If True, list the legend entries in reverse drawing order
    @param legend_bbox_to_anchor Passed to ax.legend as bbox_to_anchor
    @param legend_pos_rel_plot If 'right', shrink the axes to 80% of their width and put the legend
       centered to the right of the plot (overrides legend_location / legend_bbox_to_anchor)
    @param legend_title_pos If not None, position passed to the legend title's set_position
    @param legend_labelspacing Passed to ax.legend as labelspacing
    @param savefile The path to save the plot to, if None plot is not saved
    @param savefile_size The size of the saved plot
    @param show_plot If True, display the plot on the screen via a call to plt.show()
    @return fig, ax The figure and axis the plot was created in
    @raise ValueError If series is empty, a per-series list has the wrong length, or
       series_use_labels is set for multi-category data
    """

    # Validate
    if series is None or len(series) == 0:
        raise ValueError('No data series')
    num_series = len(series)

    if len(series_colors) != num_series:
        raise ValueError('You must define a color for every series')

    if series_labels is None:
        series_labels = [None for _ in series]
    if len(series_labels) != num_series:
        raise ValueError('You must define a label for every series')

    if series_errs is None:
        series_errs = [None for _ in series]
    if len(series_errs) != num_series:
        raise ValueError(
            'series_errs is not None. Must provide error value for every series.'
        )

    if series_err_colors is None:
        series_err_colors = ['black' for _ in series]
    if len(series_err_colors) != num_series:
        raise ValueError('Must provide an error bar color for every series')

    if series_color_emphasis is None:
        series_color_emphasis = [False for _ in series]
    if len(series_color_emphasis) != num_series:
        raise ValueError(
            'The emphasis list must be the same length as the series_colors list'
        )

    if series_use_labels and len(series[0]) > 1:
        raise ValueError('Only series containing one category may be labeled.')

    if series_style is None:
        series_style = [dict() for _ in series]

    fig, ax = plt.subplots()
    plot_utils.configure_fonts(fontsize=fontsize,
                               legend_fontsize=legend_fontsize)

    if plot_yinvert:
        ax.invert_yaxis()

    spacing = category_padding * barwidth
    num_categories = len(series[0])
    category_width = spacing + barwidth
    # Left edge of the (single, stacked) bar of every category.
    index = numpy.arange(num_categories) * category_width + xpadding

    # Running top of the stack in every category.
    y_offset = numpy.zeros(num_categories)

    for idx in range(num_series):
        style = dict(
            label=series_labels[idx],
            color=colors.get_plot_color(series_colors[idx],
                                        emphasis=series_color_emphasis[idx]),
            ecolor=colors.get_plot_color(series_err_colors[idx]),
            linewidth=0,
            # Bars start at their position; matplotlib >= 2.0 centers them by
            # default, but the tick and limit math below assumes edges.
            align='edge',
        )
        style.update(series_style[idx])

        # Positional x: the `left=` keyword was removed in matplotlib 3.0.
        ax.bar(index, series[idx], barwidth,
               bottom=y_offset,
               yerr=series_errs[idx],
               **style)
        # Rebind rather than `+=`: the previous array was handed to ax.bar.
        y_offset = y_offset + series[idx]

    # Label the plot
    if plot_ylabel is not None:
        ax.set_ylabel(plot_ylabel)
    if plot_xlabel is not None:
        ax.set_xlabel(plot_xlabel)
    if plot_title is not None:
        ax.set_title(plot_title)

    # Add category tick marks
    if series_use_labels:
        ticks = (xpadding + numpy.arange(num_series) *
                 (barwidth + series_padding) + 0.5 * barwidth)
        ax.set_xticks(ticks)
        ax.set_xticklabels(series_labels)
    elif category_ticks:
        ax.set_xticks(index + .5 * barwidth)
        if category_labels is not None:
            ax.set_xticklabels(category_labels)
    else:
        ax.set_xticks([])

    # Set the x-axis limits
    ax.set_xlim([
        0, 2 * xpadding + num_categories *
        (category_width + (num_series - 1) * series_padding) - spacing
    ])

    # Legend
    if legend_location is not None:
        if legend_pos_rel_plot == 'right':
            # http://stackoverflow.com/questions/4700614/how-to-put-the-legend-out-of-the-plot
            # Shrink current axis by 20% and put the legend to its right.
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
            legend_location = 'center left'
            legend_bbox_to_anchor = (1, 0.5)

        # http://matplotlib.org/1.3.1/users/legend_guide.html
        handles, labels = ax.get_legend_handles_labels()
        if legend_reverse_label:
            handles, labels = handles[::-1], labels[::-1]
        legend = ax.legend(handles, labels,
                           title=legend_title,
                           labelspacing=legend_labelspacing,
                           frameon=False,
                           loc=legend_location,
                           bbox_to_anchor=legend_bbox_to_anchor)
        if legend_title_pos is not None:
            legend.get_title().set_position(legend_title_pos)

    # Make the axis pretty
    plot_utils.simplify_axis(ax)

    # Save the file
    if savefile is not None:
        plot_utils.output(fig,
                          savefile,
                          savefile_size,
                          fontsize=fontsize,
                          legend_fontsize=legend_fontsize)

    # Show
    if show_plot:
        plt.show()

    return fig, ax