コード例 #1
0
def plotConvergence(simId, expName, paramName, values, show=True, save=False):
    '''
    Plot how summary statistics of `values` evolve as more trials are included.

    The mean, standard deviation, skewness, and the width of the 95% interval
    (97.5th minus 2.5th percentile) are computed over the first N values for
    increasing N. Each statistic is then plotted against N so we can see
    when (if) convergence occurs.

    :param simId: (int) simulation id; shown in the figure footer and used
       in the saved file name
    :param expName: (str) experiment name; shown in the figure footer and
       used in the saved file name
    :param paramName: (str) used as the plot title and in the saved file name
    :param values: object providing count(), slicing, mean(), std() and skew()
       (presumably a pandas.Series -- confirm against callers)
    :param show: (bool) if True, display each figure
    :param save: (bool) if True, save each figure to the path returned by
       makePlotPath()
    :return: none
    '''
    _logger.debug("Generating convergence plots...")
    count = values.count()
    if count == 0:
        _logger.debug("No values available; skipping convergence plots")
        return

    results = {'Mean': [], 'Stdev': [], 'Skewness': [], '95% CI': []}

    # Step by 1/20th of the data (at most 100), but never by less than 1:
    # with fewer than 20 values the integer division yields 0, which is not
    # a valid step for range().
    increment = max(1, min(100, count // 20))

    # Sample sizes never exceed the data size, and the final point always
    # covers the complete data set.
    nValues = list(range(increment, count + 1, increment))
    if nValues[-1] != count:
        nValues.append(count)

    for N in nValues:
        sublist = values[:N]
        results['Mean'].append(sublist.mean())
        results['Stdev'].append(sublist.std())
        results['Skewness'].append(sublist.skew())

        ciLow  = np.percentile(sublist, 2.5)
        ciHigh = np.percentile(sublist, 97.5)
        results['95% CI'].append(ciHigh - ciLow)

    # Insert zero value at position 0 for all lists to ensure proper scaling
    nValues.insert(0, 0)
    for dataList in results.values():
        dataList.insert(0, 0)

    labelsize = 12
    for key, dataList in results.items():
        plt.clf()   # clear previous figure
        ax = plt.gca()
        ax.tick_params(axis='x', labelsize=labelsize)
        ax.tick_params(axis='y', labelsize=labelsize)
        plt.plot(nValues, dataList)
        plt.title("%s" % paramName, size='large')

        # show vertical grid lines only
        ax.yaxis.grid(False)
        ax.xaxis.grid(True)
        plt.xlabel('Trials', size='large')
        plt.ylabel(key, size='large')
        plt.figtext(0.12, 0.02, "SimId=%d, Exp=%s" % (simId, expName),
                    color='black', weight='roman', size='x-small')

        if save:
            filename = makePlotPath("%s-s%02d-%s-%s" % (expName, simId, paramName, key), simId)
            _logger.debug("Saving convergence plot to %s", filename)
            plt.savefig(filename)

        if show:
            plt.show()

    plt.close(plt.gcf())
コード例 #2
0
def plotForcingSubplots(tsdata, filename=None, ci=95, show_figure=False, save_fig_kwargs=None):
    """
    Plot one time-series panel per experiment, side by side, sharing the Y axis.

    :param tsdata: (pandas.DataFrame) must contain 'expName' and 'runId'
       columns; every other column is melted into ('year', 'value') pairs,
       so they are presumably one column per year -- confirm against callers
    :param filename: (str) if given, the figure is saved to this path
    :param ci: confidence interval, passed through to tsm.tsplot
    :param show_figure: (bool) if True, display the figure
    :param save_fig_kwargs: (dict) extra keyword arguments for savefig();
       ignored unless it is a dict
    :return: the matplotlib figure
    """
    sns.set_context('paper')
    expList = tsdata['expName'].unique()

    ncols = len(expList)
    figsize = (2 * ncols, 2)

    # squeeze=False always returns a 2-D array of axes. With the default, a
    # single experiment yields a bare Axes object that can be neither
    # iterated nor indexed.
    fig, axes = plt.subplots(nrows=1, ncols=ncols, sharey=True,
                             figsize=figsize, squeeze=False)
    axes = axes[0]      # the single row of panels

    def dataForExp(expName):
        # Boolean mask rather than a formatted query() string, so that
        # experiment names containing quotes are handled correctly.
        df = tsdata[tsdata['expName'] == expName].drop(['expName'], axis=1)
        return pd.melt(df, id_vars=['runId'], var_name='year')

    for idx, (ax, expName) in enumerate(zip(axes, expList)):
        df = dataForExp(expName)

        # The panel title is the part of the name before the first '-'
        title = expName.split('-', 1)[0]
        ax.set_title(title.capitalize())

        tsm.tsplot(df, time='year', unit='runId', value='value', ci=ci, ax=ax)

        # Y axis is shared, so only the left-most panel is labeled
        ax.set_ylabel('W m$^{-2}$' if idx == 0 else '')
        ax.set_xlabel('') # no need to say "year"
        ax.axhline(0, color='navy', linewidth=0.5, linestyle='-')
        plt.setp(ax.get_xticklabels(), rotation=270)

    plt.tight_layout()

    # Save the file
    if filename:
        if isinstance(save_fig_kwargs, dict):
            fig.savefig(filename, **save_fig_kwargs)
        else:
            fig.savefig(filename)

    # Display the figure
    if show_figure:
        plt.show()

    return fig
コード例 #3
0
    def plotParallelCoordinates(self, inputDF, resultSeries, numInputs=None,
                                filename=None, extra=None, inputBins=None,
                                outputLabels=['Low', 'Medium', 'High'],
                                quantiles=False, normalize=True, invert=False,
                                show=False, title=None, rotation=None):
        '''
        Plot a parallel coordinates figure.

        :param inputDF: (pandas.DataFrame) trial inputs
        :param resultSeries: (pandas.Series) results to categorize lines
        :param numInputs: (int) the number of inputs to plot, choosing these
           from the most-highly correlated (or anti-correlated) to the lowest.
           If not provided, all variables in `inputDF` are plotted.
        :param filename: (str) name of graphic file to create
        :param extra: (str) text to draw down the right side, labeling the figure
        :param inputBins: (int) the number of bins to use to quantize inputs
        :param outputLabels: (list of str) labels to assign to outputs (and thus the number
           of bins to group the outputs into.)
        :param quantiles: (bool) create bins with equal numbers of values rather than
           bins of equal boundary widths. (In pandas terms, use qcut rather than cut.)
        :param normalize: (bool) normalize values to percentages of the range for each var.
        :param invert: (bool) Plot negatively correlated values as (1 - x) rather than (x).
           Implies normalization.
        :param show: (bool) If True, show the figure.
        :param title: (str) Figure title
        :param rotation: if not None, the rotation applied to the X-axis tick labels
        :return: none
        '''
        from pandas.plotting import parallel_coordinates

        corrDF = getCorrDF(inputDF, resultSeries)
        numInputs = numInputs or len(corrDF)
        cols = list(corrDF.index[:numInputs])

        # isolate the top-correlated columns
        inputDF = inputDF[cols]

        # trim down to trials with result (in case of failures). Label-based
        # selection; DataFrame.ix no longer exists in current pandas.
        inputDF = inputDF.loc[resultSeries.index]

        if normalize or invert:
            inputDF = normalizeDF(inputDF)

        if invert:
            for name in cols:
                # flip neg. correlated values to reduce line crossings
                if corrDF.spearman[name] < 0:
                    inputDF[name] = 1 - inputDF[name]
                    inputDF.rename(columns={name: "(1 - %s)" % name}, inplace=True)
            cols = inputDF.columns

        # optionally quantize inputs into the given number of bins
        plotDF = binColumns(inputDF, bins=inputBins) if inputBins else inputDF.copy()

        # split results into equal-size or equal-quantile bins
        cutFunc = pd.qcut if quantiles else pd.cut
        plotDF['category'] = cutFunc(resultSeries, len(outputLabels), labels=outputLabels)

        parallel_coordinates(plotDF, 'category', cols=cols, alpha=0.4,
                             colormap='rainbow')
        fig = plt.gcf()
        fig.canvas.draw()       # so that ticks / labels are generated

        if rotation is not None:
            plt.xticks(rotation=rotation)

        # Labels can come out as follows for, say, 4 bins:
        # [u'', u'0.0', u'0.5', u'1.0', u'1.5', u'2.0', u'2.5', u'3.0', u'']
        # We eliminate the "x.5" labels by substituting '' and convert the remaining
        # numerical values to integers (i.e., eliminating ".0")
        def _fixTick(text):
            if inputBins:
                return '' if (not text or text.endswith('.5')) else str(int(float(text)))

            # If not binning, just show values on Y-axis
            return text

        locs, ylabels = plt.yticks()
        ylabels = [_fixTick(t.get_text()) for t in ylabels]
        plt.yticks(locs, ylabels)

        if extra:
            printExtraText(fig, extra, loc='top', color='lightgrey', weight='ultralight', fontsize='xx-small')

        plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        if title:
            plt.title(title)

        # Save before showing, and through the figure object: once show()
        # returns, the current pyplot figure may be a new, empty one.
        if filename:
            _logger.debug("Saving parallel coordinates plot to %s", filename)
            fig.savefig(filename, bbox_inches='tight')

        if show:
            plt.show()

        plt.close(fig)
コード例 #4
0
def plotHistogram(values, xlabel=None, ylabel=None, title=None, xmin=None, xmax=None,
                  extra=None, extraColor='grey', extraLoc='right',
                  hist=True, showCI=False, showMean=False, showMedian=False,
                  color=None, shade=False, kde=True, show=True, filename=None):
    '''
    Plot a histogram and/or kernel density estimate of `values`, optionally
    marking the 2.5% / 97.5% percentiles, the mean, and the median with
    labeled vertical lines.

    :param values: the data to plot; must provide count() (presumably a
       pandas.Series -- confirm against callers)
    :param xlabel: (str) label for the X axis, if any
    :param ylabel: (str) label for the Y axis, if any
    :param title: (str) figure title, if any
    :param xmin: lower X-axis limit; None leaves it unset
    :param xmax: upper X-axis limit; None leaves it unset
    :param extra: extra text, passed to printExtraText()
    :param extraColor: (str) color for the extra text
    :param extraLoc: (str) location for the extra text
    :param hist: (bool) passed to seaborn's distplot()
    :param showCI: (bool) if True, mark the 2.5th and 97.5th percentiles
    :param showMean: (bool) if True, mark the mean
    :param showMedian: (bool) if True, mark the median
    :param color: color of the distribution; defaults to the blue of the
       "Set1" palette
    :param shade: (bool) passed to distplot() as kde_kws['shade']
    :param kde: (bool) passed to seaborn's distplot()
    :param show: (bool) if True, display the figure
    :param filename: (str) if given, the figure is saved to this path
    :return: none
    '''
    fig = plt.figure()

    colorSet = "Set1"
    sns.set_style("white")
    sns.set_palette(colorSet, desat=0.6)
    red, blue, green, purple = sns.color_palette(colorSet, n_colors=4)

    if color is None:
        color = blue

    # Use more bins for larger samples; for small ones, defer to distplot()
    count = values.count()
    if count > 150:
        bins = count // 10
    elif count > 50:
        bins = count // 5
    elif count > 20:
        bins = count // 2
    else:
        bins = None

    sns.distplot(values, hist=hist, bins=bins, kde=kde, color=color, kde_kws={'shade': shade})

    if xlabel:
        plt.xlabel(xlabel)
    if ylabel:
        plt.ylabel(ylabel)

    sns.despine()

    if title:
        t = plt.title(title)
        t.set_y(1.02)

    printExtraText(fig, extra, color=extraColor, loc=extraLoc)

    if xmin is not None or xmax is not None:
        ax = plt.gca()
        ax.set_autoscale_on(False)
        ax.set_xlim(xmin, xmax)

    # showMedian is part of this test so that requesting only the median
    # still draws its marker.
    if showCI or showMean or showMedian:
        _, ytop = plt.ylim()
        xlo, xhi = plt.xlim()
        textSize = 9
        deltax = (xhi - xlo) * 0.01     # small offset so labels clear the line

        def mark(x, label, lineColor, labely):
            # Draw a vertical line at x with a rotated label beside it
            plt.axvline(x, color=lineColor, linestyle='solid', linewidth=2)
            plt.text(x + deltax, labely, label, color=lineColor, size=textSize, rotation=90)

        if showCI:
            ciLow  = np.percentile(values, 2.5)
            ciHigh = np.percentile(values, 97.5)
            mark(ciLow,  '2.5%%=%.2f'  % ciLow,  red, ytop * 0.95)
            mark(ciHigh, '97.5%%=%.2f' % ciHigh, red, ytop * 0.95)

        if showMean:
            mean = np.mean(values)
            mark(mean, 'mean=%.2f' % mean, green, ytop * 0.95)

        if showMedian:
            median = np.percentile(values, 50)
            # placed lower than the other labels
            mark(median, 'median=%.2f' % median, purple, ytop * 0.50)

    # Save before showing so the file is written from the fully drawn
    # figure regardless of what happens to the window afterwards.
    if filename:
        _logger.info("plotHistogram writing to: %s", filename)
        fig.savefig(filename)

    if show:
        plt.show()

    plt.close(fig)
コード例 #5
0
def plotTornado(data, colname='value', labelsize=9, title=None, color=None, height=0.8,
                maxVars=DEFAULT_MAX_TORNADO_VARS, rlabels=None, xlabel='Contribution to variance', figsize=None,
                show=True, filename=None, extra=None, extraColor='grey', extraLoc='right'):
    '''
    Draw a "tornado" chart: one horizontal bar per variable, largest at the top.

    :param data: A sorted DataFrame or Series indexed by variable name. For a
                 DataFrame, the bar values are read from column `colname`
                 and, if rlabels is set, a column of that name holds
                 descriptive labels to display.
    :param colname: Name of the DataFrame column holding the bar values
    :param labelsize: font size for labels
    :param title: If not None, the title to show
    :param color: The color of the horizontal bars. Defaults to the first
                  color of seaborn's "deep" palette.
    :param height: Bar height (accepted but not currently applied to the plot)
    :param maxVars: The maximum number of variables to display
    :param rlabels: If not None, the name of a column holding values to show on the right
    :param xlabel: Label for X-axis
    :param figsize: tuple for desired figure size. Defaults to (12,6) if rlabels else (8,6).
    :param show: If True, the figure is displayed on screen
    :param filename: If not None, the figure is saved to this file
    :param extra: Extra text to display in a lower corner of the plot (see extraLoc)
    :param extraColor: (str) color for extra text
    :param extraLoc: (str) location of extra text, i.e., 'right', or 'left'.
    :return: nothing
    '''
    # len() works for both a Series and a DataFrame, whereas unpacking
    # .shape into two names fails for a Series.
    count = len(data)

    if 0 < maxVars < count:
        data = data.iloc[:maxVars]       # Truncate to the top "maxVars" rows
        count = maxVars

    # Reverse the order so the larger (abs) values are at the top
    data = data.iloc[::-1]

    itemNums = list(range(count))

    if not figsize:
        figsize = (12, 6) if rlabels else (8, 6)

    # if it's a dataframe, we expect to find the data in the value column
    values = data if isinstance(data, pd.Series) else data[colname]

    if color is None:
        color = sns.color_palette("deep", 1)

    values.plot(kind="barh", color=color, figsize=figsize,
                xlim=(-1, 1), ylim=(-1, count), xticks=np.arange(-0.8, 1, 0.2))

    plt.xlabel(xlabel)

    right = 0.6 if rlabels else 0.9
    plt.subplots_adjust(left=0.3, bottom=0.1, right=right, top=0.9)  # more room for rlabels

    fig = plt.gcf()
    ax  = plt.gca()

    # Keyword set that hides tick marks on every side while leaving the
    # tick labels in place.
    noTickMarks = dict(axis='both', which='major',
                       left=False, right=False, top=False, bottom=False)

    ax.xaxis.tick_top()
    ax.tick_params(axis='x', labelsize=labelsize)
    ax.tick_params(axis='y', labelsize=labelsize)

    # Fix the tick positions first; labels are assigned to existing ticks
    ax.set_yticks(itemNums)
    ax.set_yticklabels(data.index)

    if rlabels:
        ax2 = plt.twinx()
        ax2.set_ylim(-1, count)
        ax2.tick_params(axis='y', labelsize=labelsize)
        ax2.set_yticks(itemNums)
        ax2.set_yticklabels(data[rlabels])
        ax2.tick_params(**noTickMarks)

    # show vertical grid lines only
    ax.yaxis.grid(False)
    ax.xaxis.grid(True)

    # Remove tickmarks from both axes
    ax.tick_params(**noTickMarks)

    if title:
        plt.title(title, y=1.05)  # move title up to avoid tick labels

    printExtraText(fig, extra, loc=extraLoc, color=extraColor)

    # Save before showing so the file is written from the fully drawn
    # figure regardless of what happens to the window afterwards.
    if filename:
        _logger.debug("Saving tornado plot to %s", filename)
        fig.savefig(filename)

    if show:
        plt.show()

    plt.close(fig)
コード例 #6
0
def plotTimeSeries(datasets, timeCol, unit, valueCol='value', estimator=np.mean, estimator_linewidth=1.5,
                   ci=90, legend_loc='upper left', legend_labels=None, legend_name=None, title=None,
                   xlabel=None, ylabel=None, label_font=None, title_font=None, xlim=None, ylim=None,
                   ymin=None, ymax=None, text_label=None, show_figure=True, filename=None,
                   save_fig_kwargs=None, figure_style='darkgrid', palette_name=None, extra=None):
    """
    Plot one or more timeseries with flexible representation of uncertainty.

    Each dataset is drawn on the same axes by tsm.tsplot, in its own color.

    Parameters
    ----------
    datasets : ndarray, dataframe, or list of ndarrays or dataframes
        Data for the plot. A single array or dataframe is treated as a
        one-element list.
    timeCol : string
        Passed to tsm.tsplot as its `time` argument.
    unit : string
        Passed to tsm.tsplot as its `unit` argument.
    valueCol : string
        Passed to tsm.tsplot as its `value` argument.
    estimator : function
        Passed to tsm.tsplot as its `estimator` argument.
    estimator_linewidth : float
        Line width passed to tsm.tsplot.
    ci : float or list of floats in [0, 100]
        Confidence interval size(s), passed to tsm.tsplot.
    legend_loc : String or float
        Location of the legend on the figure
    legend_labels : string or list of strings
        One label per dataset, passed to tsm.tsplot as `condition`. A
        legend is drawn only if labels are given.
    legend_name : string
        Legend title.
    title : string
        Plot title
    xlabel : string
        x axis label
    ylabel : string
        y axis label
    label_font : dict
        Font properties for axis labels and `text_label`; 'size' defaults
        to 'medium'. The dict passed in is not modified.
    title_font : dict
        Font properties for the title; 'size' defaults to 'large'. The
        dict passed in is not modified.
    xlim : sequence of two values
        Lower and upper x axis limits
    ylim : sequence of two values
        Lower and upper y axis limits; ignored if ymin or ymax is given
    ymin, ymax : float
        Individual y axis limits; either may be None
    text_label : string or list of strings
        if a list of strings, each string gets put on a separate line
    show_figure : bool
        Boolean indicating whether figure should be shown
    filename : string
        Filename used in saving the figure
    save_fig_kwargs : dict
        Other keyword arguments are passed to savefig() call
    figure_style :
        Seaborn figure background styles, options include:
        darkgrid, whitegrid, dark, white
    palette_name : seaborn palette
        Palette for the main plots and error representation. If None, the
        current seaborn palette is used.
    extra : string
        Extra text passed to printExtraText()

    Returns
    -------
    fig : matplotlib figure

    """

    # Set plot style
    sns.set_style(figure_style)

    # Normalize to a list of datasets first, so that everything sized by
    # len(datasets) below counts datasets rather than the rows of one.
    if isinstance(datasets, (np.ndarray, pd.DataFrame)):
        datasets = [datasets]

    legend = legend_labels is not None
    if legend_labels is None:
        legend_labels = [None] * len(datasets)
    elif isinstance(legend_labels, str):
        legend_labels = [legend_labels]

    # Passed positionally because the keyword's name has not been stable
    # across seaborn releases; None selects the current palette.
    colors = sns.color_palette(palette_name, n_colors=len(datasets))

    # Create the plots
    fig, ax = plt.subplots()

    for color, data, series_name in zip(colors, datasets, legend_labels):
        # tsm.tsplot is used instead of sns.tsplot because the standard
        # version computes the CI with different semantics
        tsm.tsplot(data, time=timeCol, value=valueCol, unit=unit, ci=ci, ax=ax, color=color,
                   condition=series_name, estimator=estimator, linewidth=estimator_linewidth)

    # Work on copies so the caller's font dicts are not altered by the
    # defaults filled in here.
    label_font = dict(label_font or {})
    label_font.setdefault('size', 'medium')

    title_font = dict(title_font or {})
    title_font.setdefault('size', 'large')

    # Add the plot labels
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontdict=label_font)

    if ylabel is not None:
        ax.set_ylabel(ylabel, fontdict=label_font)

    if title is not None:
        ax.set_title(title, fontdict=title_font)

    printExtraText(fig, extra, color='grey', loc='right')

    if text_label is not None:
        x0, x1, y0, y1 = ax.axis()
        if not isinstance(text_label, str):
            text_label = '\n'.join(text_label)
        # 3% in from the left edge, 70% of the way up the axes
        ax.text(x0 + (x1 - x0) * .03, y0 + (y1 - y0) * .7, text_label, fontdict=label_font)

    if legend:
        ax.legend(loc=legend_loc, title=legend_name, prop={'size': 'medium'})

    for label in (ax.get_xticklabels() + ax.get_yticklabels()):
        label.set(fontsize='medium')

    # Axis limits
    if xlim is not None:
        ax.set_xlim(xlim[0], xlim[1])

    if ymin is not None or ymax is not None:
        ax.set_autoscale_on(False)
        ax.set_ylim(ymin, ymax)

    elif ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])

    # Save the file
    if filename:
        if isinstance(save_fig_kwargs, dict):
            fig.savefig(filename, **save_fig_kwargs)
        else:
            fig.savefig(filename)

    # Display the figure
    if show_figure:
        plt.show()

    return fig