Exemplo n.º 1
0
def scatter_plots_lu(lower, upper, lower_variables=None, upper_variables=None, lowwt=None, uppwt=None,
                    lowlabels=None, upplabels=None, nmax=None,
                    pad=0.0, align_orient=False, titles=None, titlepads=None, titlesize=None,
                    s=None, c=None, alpha=None, cbar=True,
                    cbar_label=None, cmap=None, clim=None, stat_blk=None, stat_xy=None,
                    stat_ha=None, stat_fontsize=None, roundstats=None, sigfigs=None,
                    xlim=None, ylim=None, label='_nolegend_', output_file = None, out_kws = None,
                    grid=True, axis_xy=None, figsize=None, return_handle=False, **kwargs):
    '''
    Function which wraps the scatter_plot function, creating an upper/lower matrix triangle of
    scatterplots for comparing the scatter of multiple variables in two data sets.

    Parameters:
        lower(np.ndarray or pd.DataFrame or gs.DataFile): 2-D data array, which should be
            dimensioned as (ndata, nvar). Alternatively, specific variables may be selected
            with the variables argument. If a DataFile is passed and data.variables has a length
            greater than 1, those columns will be treated as the variables to plot.
            This data is plotted in the lower triangle.
        upper(np.ndarray or pd.DataFrame or gs.DataFile): see the description for lower, although
            this data is plotted in the upper triangle.

    Keyword arguments:
        lower_variables(nvar-tuple str): indicates the column names to treat as variables in lower
        upper_variables(nvar-tuple str): indicates the column names to treat as variables in upper
        lowwt(np.ndarray or pd.Series or str or bool): array with weights that are used in the
            calculation of displayed statistics for the lower data. Alternatively, a str may
            specify the weight column in lower. If lower is a DataFile and lower.wt is not None,
            then wt=True may be used to apply those weights.
        uppwt(np.ndarray or pd.DataFrame or str or bool): see the description for
            lowwt, although these weights are applied to upper.
        lowlabels(nvar-tuple str): labels for lower, which are drawn from lower if None
        upplabels(nvar-tuple str): labels for upper, which are drawn from upper if None
        nmax (int): specify the maximum number of scatter points that should be displayed, which
            may be necessary due to the time-requirements of plotting many data. If specified,
            a nmax-length random sub-sample of the data is plotted. Note that this does not impact
            summary statistics.
        pad(float or 2-tuple): space between each panel, which may be negative or positive. A tuple
            of (xpad, ypad) may also be used.
        align_orient(bool): align the orientation of plots in the upper and lower triangle (True),
            which causes the lower triangle plots to be flipped (x and y axes) from their
            standard symmetric orientation.
        titles(2-tuple str): titles of the lower and upper triangles (lower title, upper title)
        titlepads(2-tuple float): padding of the titles to the left of the lower triangle
            titlepads[0] and above the upper triangle (titlepads[1]). Typical required numbers
            are in the range of 0.01 to 0.5, depending on figure dimensioning. Either entry may
            be None to use its default. The passed object is never modified.
        titlesize(int): size of the title font
        s(float or np.ndarray or pd.Series): size of each scatter point. Based on
            Parameters['plotting.scatter_plot.s'] if None.
        c(color or np.ndarray or pd.Series): color of each scatter point, as an array or valid
            Matplotlib color. Alternatively, 'KDE' may be specified to color each point according
            to its associated kernel density estimate. Based on Parameters['plotting.scatter_plot.c']
            if None.
        alpha(float): opacity of the scatter. Based on Parameters['plotting.scatter_plot.alpha'] if None.
        cmap(str): A matplotlib colormap object or a registered matplotlib colormap name
        clim(2-tuple float): Data minimum and maximum values
        cbar(bool): plot a colorbar for the color of the scatter (if variable)? (default=True)
        cbar_label(str): colorbar label (automated if KDE coloring). If None, a label is
            requested from gs.get_label(c).
        stat_blk(str or tuple): statistics to place in the plot, which should be 'all' or
            a tuple that may contain ['count', 'pearson', 'spearman']. Based on
            Parameters['plotting.scatter_plot.stat_blk'] if None. Set to False to disable.
        stat_xy(2-tuple float): X, Y coordinates of the annotated statistics in figure
            space. Based on Parameters['plotting.scatter_plot.stat_xy'] if None.
        stat_ha(str): Horizontal alignment parameter for the annotated statistics. Can be
            ``'right'``, ``'left'``, or ``'center'``. If None, based on
            Parameters['plotting.stat_ha']
        stat_fontsize(float): the fontsize for the statistics block. If None, based on
            Parameters['plotting.stat_fontsize']. If less than 1, it is the fraction of the
            matplotlib.rcParams['font.size']. If greater than 1, it is the absolute font size.
        roundstats(bool): Indicate if the statistics should be rounded to the number of digits or
            to a number of significant figures (e.g., 0.000 vs. 1.14e-5). The number of digits or
            figures used is set by the parameter ``sigfigs``. Based on
            Parameters['plotting.roundstats'] if None.
        sigfigs (int): Number of significant figures or number of digits (depending on
            ``roundstats``) to display for the float statistics. Based on
            Parameters['plotting.sigfigs'] if None.
        grid(bool): plot grid lines in each panel? Based on Parameters['plotting.grid'] if None.
        axis_xy(bool): if True, mimic a GSLIB-style scatter_plot, where only the bottom and left axes
            lines are displayed. Based on Parameters['plotting.axis_xy'] if None.
        xlim(2-tuple float): x-axis limits - xlim[0] to xlim[1]. Based on the data if None
        ylim(2-tuple float): y-axis limits - ylim[0] to ylim[1]. Based on the data if None.
        label(str): label of scatter for legend. Accepted for signature compatibility; it is
            not referenced by this function.
        output_file (str): Output figure file name and location
        out_kws (dict): Optional dictionary of permissible keyword arguments to pass to
            :func:`gs.export_image() <pygeostat.plotting.export_image.export_image>`
        figsize(2-tuple float): size of the created figure
        return_handle(bool): accepted for signature compatibility; it is not referenced by
            this function and the figure is always returned.
        **kwargs: Optional permissible keyword arguments that are forwarded to scatter_plot
            for every panel

    Return:
        matplotlib figure handle

    Raises:
        ValueError: if lower and upper resolve to a different number of variables, if fewer
            than two variables are available, or if titles does not have a length of 2.

    **Examples:**

    Plot with varying orientations that provide correct symmetry (above) and ease of comparison
    (below). Here, the data is treated as both the data and a realization (first two arguments)
    for the sake of demonstration.

    .. plot::

        import pygeostat as gs
        import numpy as np

        # Load the data, which registers the variables attribute
        data_file1 = gs.ExampleData('point3d_ind_mv')
        data_file2 = gs.ExampleData('point3d_ind_mv')
        mask = np.random.rand(len(data_file2))<0.3
        data_file2.data = data_file2.data[mask]

        # Plot with the standard orientation
        fig = gs.scatter_plots_lu(data_file1, data_file2, titles=('Data', 'Realization'), s=10, nmax=1000,
                             stat_xy=(0.95, 0.95), pad=(-1, -1), figsize=(10, 10))

        # Plot with aligned orientation to ease comparison
        fig = gs.scatter_plots_lu(data_file1, data_file2, titles=('Data', 'Realization'), s=10, nmax=1000,
                             stat_xy=(0.95, 0.95), pad=(-1, -1), figsize=(10, 10), cmap='jet',
                             align_orient=True)

    '''
    import numpy as np
    import pandas as pd
    import pygeostat as gs
    import matplotlib as mpl
    # Parse the data, variables and wt inputs, returning appropriate inputs
    lower, lowwt, lowlabels = _handle_variables_wt(lower, lower_variables, lowwt, lowlabels)
    upper, uppwt, upplabels = _handle_variables_wt(upper, upper_variables, uppwt, upplabels)
    nvar = upper.shape[1]
    if lower.shape[1] != nvar:
        raise ValueError('upper and lower were coerced into differing number of variables!')
    if nvar < 2:
        # A single variable leaves no off-diagonal panel to draw (and plt.subplots would
        # return a bare Axes rather than a 2-D array), so fail early with a clear message.
        raise ValueError('at least two variables are required to create the scatter plots!')

    # Arguments that are identical for every panel. ``c`` is captured here as passed by the
    # caller; it is only resolved against Parameters further below for the colorbar logic.
    panel_kws = dict(s=s, c=c, nmax=nmax, alpha=alpha, clim=clim, cmap=cmap, cbar=False,
                     cbar_label=False, stat_blk=stat_blk, stat_xy=stat_xy,
                     stat_fontsize=stat_fontsize, stat_ha=stat_ha, roundstats=roundstats,
                     sigfigs=sigfigs, xlim=xlim, ylim=ylim, grid=grid, axis_xy=axis_xy,
                     return_plot=True)

    # Iterate over the pairs
    fig, axes = plt.subplots(nvar, nvar, figsize=figsize)
    for i in range(nvar):
        axes[i][i].axis('off')
        for j in range(nvar):
            if i == j:
                continue
            if i < j:
                # Upper triangle: only the panel adjacent to the diagonal keeps its ticks
                source, wt, labels = upper, uppwt, upplabels
                xcol, ycol = j, i
                xtickoff = ytickoff = (i + 1 != j)
            else:
                # Lower triangle: x/y are swapped when the orientation is aligned with the
                # upper triangle. Ticks are kept on the bottom row (x) and first column (y).
                source, wt, labels = lower, lowwt, lowlabels
                xcol, ycol = (i, j) if align_orient else (j, i)
                xtickoff = i != nvar - 1
                ytickoff = j != 0
            _, plot = scatter_plot(source.iloc[:, xcol], source.iloc[:, ycol], wt=wt,
                                   ax=axes[i][j], xlabel=labels[xcol], ylabel=labels[ycol],
                                   **panel_kws, **kwargs)
            _tickoff(axes[i][j], xtickoff=xtickoff, ytickoff=ytickoff)

    # pad may be a scalar or an (xpad, ypad) pair
    try:
        w_pad, h_pad = pad[0], pad[1]
    except (TypeError, IndexError):
        w_pad = h_pad = pad
    fig.tight_layout(h_pad=h_pad, w_pad=w_pad)
    fig.subplots_adjust(top=.95, right=.95, left=.07)

    if titles is not None:
        if len(titles) != 2:
            raise ValueError('titles should be a 2-list of strings!')
        # Resolve the paddings into locals so that a tuple may be passed and so that the
        # caller's object is never modified
        if titlepads is not None:
            left_pad, top_pad = titlepads[0], titlepads[1]
            if left_pad is None:
                left_pad = 3.*fig.dpi
            if top_pad is None:
                top_pad = 0.0
        else:
            left_pad, top_pad = 0.08*fig.dpi, 0.01
        if titlesize is None:
            titlesize = mpl.rcParams['font.size']
        gs.supaxislabel('y', titles[0], label_prop={'weight': 'bold', 'fontsize': titlesize},
                        fig=fig, labelpad=left_pad)
        fig.suptitle(titles[1], weight='bold', fontsize=titlesize, y=0.98+top_pad)

    # Figure out if KDE was used, whether it was requested explicitly or via the default
    if c is None:
        c = Parameters['plotting.scatter_plot.c']
    kde = isinstance(c, str) and c.lower() == 'kde'
    # A colorbar is only meaningful for KDE or array-like colors
    if not kde and not isinstance(c, (pd.DataFrame, np.ndarray, pd.Series)):
        cbar = False
    # Colorbar
    if cbar:
        fig.subplots_adjust(bottom=.15)
        cbar_ax = fig.add_axes([0.1, 0.04, 0.8, 0.02])
        if kde:
            colorbar = fig.colorbar(plot, cax=cbar_ax, ticks=[0, .5, 1],
                                    orientation='horizontal')
            colorbar.ax.set_xticklabels(['Low', 'Med.', 'High'])
            colorbar.ax.set_title('Kernel Density Estimate')
        else:
            colorbar = fig.colorbar(plot, cax=cbar_ax, orientation='horizontal')
            if cbar_label is None:
                # Best effort: fall back to a label inferred from the color data
                try:
                    cbar_label = gs.get_label(c)
                except Exception:
                    cbar_label = None
            if cbar_label is not None:
                colorbar.ax.set_title(cbar_label)
    else:
        fig.subplots_adjust(bottom=.05)

    # Handle dictionary defaults
    if out_kws is None:
        out_kws = dict()

    if output_file or ('pdfpages' in out_kws):
        gs.export_image(output_file, **out_kws)

    return fig
Exemplo n.º 2
0
def scatter_plots(data, variables=None, wt=None, labels=None, nmax=None, pad=0.0, s=None, c=None,
             alpha=None, cmap=None, clim=None, cbar=True, cbar_label=None,
             stat_blk=None, stat_xy=None, stat_ha=None, stat_fontsize=None,
             roundstats=None, sigfigs=None,
             grid=None, axis_xy=None, xlim=None, ylim=None, label='_nolegend_', output_file = None, out_kws = None,
             figsize=None, **kwargs):
    '''
    Function which wraps the scatter_plot function, creating an upper matrix triangle of scatterplots
    for multiple variables.

    Parameters:
        data(np.ndarray or pd.DataFrame or gs.DataFile) : 2-D data array, which should be
            dimensioned as (ndata, nvar). Alternatively, specific variables may be selected
            with the variables argument. If a DataFile is passed and data.variables has a length
            greater than 1, those columns will be treated as the variables to plot.

    Keyword arguments:
        variables(str list): indicates the column names to treat as variables in data
        wt(np.ndarray or pd.Series or str or bool): array with weights
            that are used in the calculation of displayed statistics. Alternatively, a str may
            specify the weight column in data. If data is a DataFile and data.wts is not None,
            then wt=True may be used to apply those weights.
        labels(tuple or nvar-list): labels for data, which are drawn from data if None
        nmax (int): specify the maximum number of scatter points that should be displayed, which
            may be necessary due to the time-requirements of plotting many data. If specified,
            a nmax-length random sub-sample of the data is plotted. Note that this does not impact
            summary statistics.
        pad(float or 2-tuple): space between each panel, which may be negative or positive. A tuple
            of (xpad, ypad) may also be used.
        s(float or np.ndarray or pd.Series): size of each scatter point. Based on
            Parameters['plotting.scatter_plot.s'] if None.
        c(color or np.ndarray or pd.Series): color of each scatter point, as an array or valid
            Matplotlib color. Alternatively, 'KDE' may be specified to color each point according
            to its associated kernel density estimate. Based on Parameters['plotting.scatter_plot.c']
            if None.
        alpha(float): opacity of the scatter. Based on Parameters['plotting.scatter_plot.alpha'] if None.
        cmap(str): A matplotlib colormap object or a registered matplotlib colormap name
        clim(2-tuple float): Data minimum and maximum values
        cbar(bool): plot a colorbar for the color of the scatter (if variable)? (default=True)
        cbar_label(str): colorbar label (automated if KDE coloring). If None, a label is
            requested from gs.get_label(c).
        stat_blk(str or tuple): statistics to place in the plot, which should be 'all' or
            a tuple that may contain ['count', 'pearson', 'spearman']. Based on
            Parameters['plotting.scatter_plot.stat_blk'] if None. Set to False to disable.
        stat_xy(2-tuple float): X, Y coordinates of the annotated statistics in figure
            space. Based on Parameters['plotting.scatter_plot.stat_xy'] if None.
        stat_ha(str): Horizontal alignment parameter for the annotated statistics. Can be
            ``'right'``, ``'left'``, or ``'center'``. If None, based on
            Parameters['plotting.stat_ha']
        stat_fontsize(float): the fontsize for the statistics block. If None, based on
            Parameters['plotting.stat_fontsize']. If less than 1, it is the fraction of the
            matplotlib.rcParams['font.size']. If greater than 1, it is the absolute font size.
        roundstats(bool): Indicate if the statistics should be rounded to the number of digits or
            to a number of significant figures (e.g., 0.000 vs. 1.14e-5). The number of digits or
            figures used is set by the parameter ``sigfigs``. Based on
            Parameters['plotting.roundstats'] if None.
        sigfigs (int): Number of significant figures or number of digits (depending on
            ``roundstats``) to display for the float statistics. Based on
            Parameters['plotting.sigfigs'] if None.
        grid(bool): plot grid lines in each panel? Based on Parameters['plotting.grid'] if None.
        axis_xy(bool): if True, mimic a GSLIB-style scatter_plot, where only the bottom and left axes
            lines are displayed. Based on Parameters['plotting.axis_xy'] if None.
        xlim(2-tuple float): x-axis limits - xlim[0] to xlim[1]. Based on the data if None
        ylim(2-tuple float): y-axis limits - ylim[0] to ylim[1]. Based on the data if None.
        label(str): label of scatter for legend. Accepted for signature compatibility; it is
            not referenced by this function.
        output_file (str): Output figure file name and location
        out_kws (dict): Optional dictionary of permissible keyword arguments to pass to
            :func:`gs.export_image() <pygeostat.plotting.export_image.export_image>`
        figsize(2-tuple float): size of the created figure
        **kwargs: Optional permissible keyword arguments that are forwarded to scatter_plot
            for every panel

    Return:
        matplotlib figure handle

    Raises:
        ValueError: if fewer than two variables are available to plot.


    **Example:**

    Only one basic example is provided here, although all kwargs applying to the underlying scatter_plot
    function may be applied to scatter_plots.

    .. plot::

        import pygeostat as gs

        # Load the data, which registers the variables attribute
        data_file = gs.ExampleData('point3d_ind_mv')

        # Plot with the default KDE coloring
        fig = gs.scatter_plots(data_file, nmax=1000, stat_xy=(0.95, 0.95), pad=(-1, -1), s=10,
                          figsize=(10, 10))
    '''
    import numpy as np
    import pandas as pd
    import pygeostat as gs
    # Parse the data, variables and wt inputs, returning appropriate inputs
    data, wt, labels = _handle_variables_wt(data, variables, wt, labels)
    nvar = data.shape[1]
    if nvar < 2:
        # There is no variable pair to plot; fail with a clear message rather than letting
        # plt.subplots reject a 0 x 0 grid.
        raise ValueError('at least two variables are required to create the scatter plots!')
    # Iterate over the pairs. squeeze=False guarantees a 2-D axes array even for the single
    # panel that results from two variables.
    fig, axes = plt.subplots(nvar-1, nvar-1, figsize=figsize, squeeze=False)
    for i in range(nvar-1):
        for j in range(1, nvar):
            panel = axes[i][j-1]
            if i >= j:
                # Below the diagonal of the (nvar-1) x (nvar-1) grid: nothing to draw
                panel.axis('off')
                continue
            _, plot = scatter_plot(data.iloc[:, j], data.iloc[:, i], wt=wt, s=s, c=c,
                                   alpha=alpha, clim=clim, cmap=cmap, cbar=False, nmax=nmax,
                                   cbar_label=False, stat_blk=stat_blk, stat_xy=stat_xy,
                                   stat_fontsize=stat_fontsize, return_plot=True,
                                   stat_ha=stat_ha, roundstats=roundstats, sigfigs=sigfigs,
                                   xlim=xlim, ylim=ylim, ax=panel, xlabel=labels[j],
                                   ylabel=labels[i], grid=grid, axis_xy=axis_xy, **kwargs)
            # Only the panels along the diagonal keep their tick labels
            hide_ticks = i != j-1
            _tickoff(panel, xtickoff=hide_ticks, ytickoff=hide_ticks)

    # pad may be a scalar or an (xpad, ypad) pair
    try:
        w_pad, h_pad = pad[0], pad[1]
    except (TypeError, IndexError):
        w_pad = h_pad = pad
    fig.tight_layout(h_pad=h_pad, w_pad=w_pad)

    # Figure out if KDE was used, whether it was requested explicitly or via the default
    if c is None:
        c = Parameters['plotting.scatter_plot.c']
    kde = isinstance(c, str) and c.lower() == 'kde'
    # A colorbar is only meaningful for KDE or array-like colors
    if not kde and not isinstance(c, (pd.DataFrame, np.ndarray, pd.Series)):
        cbar = False
    # Colorbar
    if cbar:
        cbar_ax = fig.add_axes([0.2, .15, .03, .25])
        if kde:
            colorbar = fig.colorbar(plot, cax=cbar_ax, ticks=[0, .5, 1])
            colorbar.ax.set_yticklabels(['Low', 'Med.', 'High'])
            colorbar.set_label('Kernel Density Estimate', ha='center', va='top', labelpad=2)
        else:
            colorbar = fig.colorbar(plot, cax=cbar_ax)
            if cbar_label is None:
                # Best effort: fall back to a label inferred from the color data
                try:
                    cbar_label = gs.get_label(c)
                except Exception:
                    cbar_label = None
            if cbar_label is not None:
                colorbar.set_label(cbar_label, ha='center', va='top', labelpad=2)

    # Handle dictionary defaults
    if out_kws is None:
        out_kws = dict()

    if output_file or ('pdfpages' in out_kws):
        gs.export_image(output_file, **out_kws)
    return fig
Exemplo n.º 3
0
def histogram_plot_simulation(simulated_data, reference_data, reference_variable=None, reference_weight=None, reference_n_sample=None,
                            simulated_column=None, griddef=None, nreal=None,
                            n_subsample=None, simulated_limits=False, ax=None,
                            figsize=None, xlim=None, title=None, xlabel=None, stat_blk='all',
                            stat_xy=(0.95, 0.05), reference_color=None, simulation_color=None, alpha=None, lw=1,
                            plot_style=None, custom_style=None, output_file=None, out_kws=None, sim_kws=None,
                            **kwargs):
    """
    histogram_plot_simulation emulates the pygeostat histogram_plot program as a means of checking histogram
    reproduction of simulated realizations to the original histogram. The use of python generators
    is a very flexible and easy means of instructing this plotting function as to what to plot.

    The function accepts five types of simulated input passed to the ``simulated_data`` argument:

        #. 1-D array like data (numpy or pandas) containing 1 or more realizations of simulated
           data.
        #. 2-D array like data (numpy or pandas) with each column being a realization and each row
           being an observation.
        #. List containing location(s) of realization file(s).
        #. String containing the location of a folder containing realization files. All files in
           the folder are read in this case.Can contain
        #. String with a wild card search (e.g., './data/realizations/*.out')
        #. Python generator object that yields a 1-D numpy array.

    The function accepts two types of reference input passed to the ``reference_data`` argument:

        #. Array like data containing the reference variable
        #. String containing the location of the reference data file (e.g., './data/data.out')

    This function uses pygeostat for plotting and numpy to calculate statistics.

    The only parameters required are ``reference_data`` and ``simulated_data``. If files are to be read or a 1-D
    array is passed, the parameters ``griddef`` and ``nreal`` are required. ``simulated_column`` is required
    for reading files as well. It is assumed that an equal number of realizations are within each
    file if multiple file locations are passed. Sub-sampling of datafiles can be completed by
    passing the parameter ``n_subsample``. If a file location is passed to ``reference_data``, the parameters
    ``reference_variable`` and ``reference_n_sample`` are required. All other arguments are optional or determined
    automatically if left at their default values. If ``xlabel`` is left to its default value of
    ``None``, the column information will be used to label the axes if present. Three keyword
    dictionaries can be defined. (1) ``sim_kws`` will be passed to pygeostat histogram_plot used for
    plotting realizations (2) ``out_kws`` will be passed to the pygeostat exportfig function and
    (3) ``**kwargs`` will be passed to the pygeostat histogram_plot used to plot the reference data.


    Two statistics block sets are available: ``'minimal'`` and the default ``'all'``. The
    statistics block can be customized to a user defined list and order. Available statistics are
    as follows:

    >>> ['nreal', 'realavg', 'realavgstd', 'realstd', 'realstdstd', 'ndat', 'refavg', 'refstd']

    Please review the documentation of the :func:`gs.set_plot_style() <pygeostat.plotting.set_plot_style>` and
    :func:`gs.export_image() <pygeostat.plotting.export_image>` functions for details on their
    parameters so that their use in this function can be understood.

    Parameters:
        simulated_data: Input simulation data
        reference_data: Input reference data

    Keyword Arguments:
        reference_variable (int, str): Required if sub-sampling reference data. The column containing the data
            to be sub-sampled
        reference_weight: 1D dataframe, series, or numpy array of declustering weights for the data. Can also
            be a string of the column in the reference_data if reference_data is a string, or a bool if reference_data.weights
            is a string
        reference_n_sample (int): Required if sub-sampling reference data. The number of data within the
            reference data file to sample from
        griddef (GridDef): A pygeostat class GridDef created using :class:`gs.GridDef
            <pygeostat.data.grid_definition.GridDef>`
        simulated_column (int): column number in the simulated data file
        nreal (int): Required if sub-sampling simulation data. The total number of realizations
            that are being plotted. If a HDF5 file is passed, this parameter can be used to limit
            the amount of realizations plotted (i.e., the first ``nreal`` realizations)
        n_subsample (int): Required if sub-sampling is used. The number of sub-samples to draw.
        ax (mpl.axis): Matplotlib axis to plot the figure
        figsize (tuple): Figure size (width, height)
        xlim (float tuple): Minimum and maximum limits of data along the x axis
        title (str): Title for the plot
        xlabel (str): X-axis label
        stat_blk (str or list): Indicate what preset statistics block to write or a specific list
        stat_xy (str or float tuple): X, Y coordinates of the annotated statistics in figure
            space. The default coordinates specify the bottom right corner of the text block
        reference_color (str): Colour of original histogram
        simulation_color (str): Colour of simulation histograms
        alpha (float): Transparency for realization variograms (0 = Transparent, 1 = Opaque)
        lw (float): Line width in points. The width provided in this parameter is used for the
            reference variogram, half of the value is used for the realization variograms.
        plot_style (str): Use a predefined set of matplotlib plotting parameters as specified by
            :class:`gs.GridDef <pygeostat.data.grid_definition.GridDef>`. Use ``False`` or ``None``
            to turn it off
        custom_style (dict): Alter some of the predefined parameters in the ``plot_style`` selected
        output_file (str): Output figure file name and location
        out_kws (dict): Optional dictionary of permissible keyword arguments to pass to
            :func:`gs.export_image() <pygeostat.plotting.export_image.export_image>`
        sim_kws: Optional dictionary of permissible keyword arguments to pass to
            :func:`gs.histogram_plot() <pygeostat.plotting.histogram_plot.histogram_plot>` for plotting realization
            histograms and by extension, matplotlib's plot function if the keyword passed is not
            used by :func:`gs.histogram_plot() <pygeostat.plotting.histogram_plot.histogram_plot>`
        **kwargs: Optional dictionary of permissible keyword arguments to pass to
            :func:`gs.histogram_plot() <pygeostat.plotting.histogram_plot.histogram_plot>` for plotting the reference
            histogram and by extension, matplotlib's plot function if the keyword passed is not
            used by :func:`gs.histogram_plot() <pygeostat.plotting.histogram_plot.histogram_plot>`

    Returns:
        ax (ax): matplotlib Axes object with the histogram reproduction plot

    **Examples:**

    .. plot:: 
    
        import pygeostat as gs
        import pandas as pd

        # Global setting using Parameters
        gs.Parameters['data.griddef'] = gs.GridDef([10,1,0.5, 10,1,0.5, 2,1,0.5])
        gs.Parameters['data.nreal'] = nreal = 100
        size = gs.Parameters['data.griddef'].count();

        reference_data = pd.DataFrame({'value':np.random.normal(0, 1, size = size)})

        # Create example simulated data
        simulated_data =pd.DataFrame(columns=['value'])
        for i, offset in enumerate(np.random.rand(nreal)*0.04 - 0.02):
            simulated_data= simulated_data.append(pd.DataFrame({'value':np.random.normal(offset, 1, size = size)}))

        gs.histogram_plot_simulation(simulated_data, reference_data, reference_variable='value',
                                title='Histogram Reproduction', grid=True)
    """
    # -----------------------------------------------------------------------
    #  Sanity checks, file type determination, and try loading fortran
    # -----------------------------------------------------------------------
    import pygeostat as gs
    from . utils import format_plot, _set_stat_fontsize
    # Figure out what type of input simulated_data is
    subsamp = False
    ndim = False
    generator = False
    folder = False
    wildcard = False
    array = False
    filelist = False
    if isinstance(simulated_data, types.GeneratorType):
        nreal = 0
        generator = True
        iterator = simulated_data
    elif isinstance(simulated_data, str) and ('*' in simulated_data):
        wildcard = True
    elif isinstance(simulated_data, str) and os.path.isdir(simulated_data):
        folder = True
    elif isinstance(simulated_data, list):
        filelist = True
    elif any([isinstance(simulated_data, pd.DataFrame), isinstance(simulated_data, np.ndarray), isinstance(simulated_data, pd.Series)]):
        array = True
        if subsamp:
            raise ValueError("Sub-sampling won't work if the data is already in memory")
    else:
        raise ValueError("The passed `simulated_data` is not a valid input format")
    if n_subsample is not None and isinstance(n_subsample, (int, float)):
        subsamp = True
    # Make sure the required parameters are passed
    if griddef is None:
        griddef = Parameters['data.griddef']
        if griddef is None:
            raise ValueError("A gs.GridDef must be passed when reading from files")

    if nreal is None:
            nreal = Parameters['data.nreal']
            if nreal is None:
                raise ValueError("The number of realizations to be read must be specified when"
                                 " reading from files")

    if any([folder, wildcard, filelist]):
        
        if simulated_column is None:
            raise ValueError("The column in the files that contains the simulation data must be"
                             " specified")
    elif array:
        if (len(simulated_data.shape) == 1) or (simulated_data.shape[1] == 1):
            if not isinstance(griddef, gs.GridDef):
                raise ValueError("If a 1-D array is passed, a gs.GridDef must be passed to"
                                 " `griddef`")
            if nreal is None:
                raise ValueError("The number of realizations must be passed if dealing with a"
                                 " 1-D array")
    # Figure out what type of input reference_data is
    if isinstance(reference_data, str) and (n_subsample is not False):
        refsubsamp = True
    else:
        refsubsamp = False
    # Try to load the subsample function if it required
    if subsamp or refsubsamp:
        if not isinstance(griddef, gs.GridDef):
            raise ValueError("A gs.GridDef is required for subsampling.")
        try:
            from pygeostat.fortran.subsample import subsample
        except:
            raise ImportError("The fortran subroutine subsample could not be loaded, please ensure"
                              " it has been compiled correctly.")

    # -----------------------------------------------------------------------
    #  Handle data input
    # -----------------------------------------------------------------------
    # Set-up variables
    realavg, realavgstd, realstd, realstdstd, refavg, refstd = ([] for i in range(6))
    # Handle pd and np input
    if array:
        if isinstance(simulated_data, pd.DataFrame):
            simulated_data = simulated_data.values
        if (len(simulated_data.shape) == 2) and (simulated_data.shape[1] > 1):
            ndim = 2
        else:
            ndim = 1
        # Handle 1-D arrays
        if ndim == 1:
            ncell = griddef.count()
        # Handle 2-D arrays
        if ndim == 2:
            nreal = simulated_data.shape[1]
    # Handle folder and wildcard searches
    if folder:
        if simulated_data[-1] != '/':
            simulated_data = simulated_data + '/'
        simulated_data = simulated_data + '*'
        files = []
        for filepath in glob.glob(simulated_data):
            files.append(filepath)
    if wildcard:
        files = []
        for filepath in glob.glob(simulated_data):
            files.append(filepath)
    if filelist:
        files = simulated_data

    if any([folder, wildcard, filelist]):
        ncell = griddef.count()
        ndim = 1
        # Check and make sure the number of files and the nreal value passed makes sense
        if nreal % len(files) != 0:
            raise ValueError(" The number of realizations passed is not divisible by the number"
                             " of files passed/found. Please make sure there are the same number"
                             " of realizations in each file and that the sum of them match the"
                             " nreal argument passed.")
        # Read the data
        simulated_data = []
        for file in files:
            data = gs.DataFile(file).data.iloc[:, simulated_column - 1]
            if simulated_limits:
                if isinstance(simulated_limits, (int, float)):
                    data = data.loc[data > simulated_limits]
                elif len(simulated_limits) == 2:
                    data = data.loc[data.between(simulated_limits[0], simulated_limits[1])]
                else:
                    raise ValueError('simulated_limits must be a value or tuple')
            simulated_data.extend(data)
        simulated_data = np.array(simulated_data)
        if simulated_limits:
            ncell = len(data)
    # Handle sub-sampling if required
    if subsamp:
        ncell = griddef.count()
        if array:
            if ndim == 1:
                simulated_data = np.reshape(simulated_data, (nreal, griddef.count())).T
                ndim = 2
            newarr = np.zeros((n_subsample, nreal))
            for ireal in range(nreal):
                ridx = np.random.permutation(ncell)[:n_subsample]
                newarr[:, ireal] = simulated_data[ridx, ireal]
            simulated_data = newarr
        else:
            # Sub-sample all of the files and combine into a single numpy array
            files = simulated_data
            file_nreal = int(nreal / len(files))
            simulated_data = []
            for fl in files:
                dump = subsample(fl, simulated_column, ncell, n_subsample, file_nreal, gs.rseed())
                dump = np.transpose(dump)
                simulated_data.extend(dump)
            simulated_data = np.array(simulated_data)
            simulated_data = np.transpose(simulated_data)
            nreal = simulated_data.shape[1]

    # Create a generator for the realizations
    if generator is False:
        def _itersimulated_data():
            for i in range(0, nreal):
                if subsamp or ndim == 2:
                    real = simulated_data[:, i]
                elif ndim == 1:
                    real = simulated_data[(i * ncell):(((i + 1) * ncell) - 1)]
                yield real
        iterator = _itersimulated_data()

    # -----------------------------------------------------------------------
    #  Plot Figure
    # -----------------------------------------------------------------------
    # Set figure style parameters

    # Handle dictionary defaults
    if sim_kws is None:
        sim_kws = dict()
    if out_kws is None:
        out_kws = dict()
    # Set-up figure
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    if simulation_color is None:
        simulation_color = Parameters['plotting.histogram_plot_simulation.simclr']
    if alpha is None:
        alpha = Parameters['plotting.histogram_plot_simulation.alpha']
    # Plot realization histograms
    for real in iterator:
        # Append intermediate realization dist statistics to variables
        if stat_blk:
            realavg.append(np.nanmean(real))
            realstd.append(np.nanstd(real))
        if generator:
            nreal += 1
        gs.histogram_plot(real, ax=ax, icdf=True, lw=(lw / 2), stat_blk=False, color=simulation_color, alpha=alpha, plot_style=False, **sim_kws)
    # Calculate more realization dist statistics if required
    if stat_blk:
        realavgstd = np.std(realavg)
        realavg = np.mean(realavg)
        realstdstd = np.std(realstd)
        realstd = np.mean(realstd)
    # Sub-sample original distribution if needed

    if isinstance(reference_data, str):
        if subsamp:
            reference_data = subsample(reference_data, reference_variable, reference_n_sample, n_subsample, 1, gs.rseed())
            reference_data = reference_data[:, 0]
        else:
            reference_data = gs.DataFile(reference_data)
            if isinstance(reference_variable, str):
                reference_variable = reference_data.gscol(reference_variable) - 1
            if isinstance(reference_weight, str):
                reference_weight = reference_data.gscol(reference_weight) - 1
                reference_weight = reference_data.data.values[:, reference_weight]
            elif isinstance(reference_weight, bool):
                if reference_weight:
                    if isinstance(reference_data.weights, str):
                        reference_weight = reference_data[reference_data.weights]
                    else:
                        raise ValueError('reference_weight=True is only valid if reference_data.weights is a string!')
            reference_data = reference_data.data.values[:, reference_variable]
    elif isinstance(reference_data, gs.DataFile):
        if isinstance(reference_weight, str):
            reference_weight = reference_data[reference_weight].values
        elif isinstance(reference_weight, bool):
            if reference_weight:
                if isinstance(reference_data.weights, str):
                    reference_weight = reference_data[reference_data.weights].values
                else:
                    raise ValueError('reference_weight=True is only valid if reference_data.weights is a string!')
        if isinstance(reference_variable, str):
            reference_data = reference_data[reference_variable].values
        elif isinstance(reference_data.variables, str):
            reference_data = reference_data[reference_data.variables].values
        elif len(list(reference_data.columns)) == 1:
            reference_data = reference_data.data.values
        else:
            raise ValueError('could not coerce reference_data into a 1D array!')

    # Plot the reference histogram
    if reference_color is None:
        reference_color = Parameters['plotting.histogram_plot_simulation.refclr']
    if not isinstance(reference_data, bool) and (reference_data is not None):
        gs.histogram_plot(reference_data, weights=reference_weight, ax=ax, icdf=True, stat_blk=False, color=reference_color,
                   lw=lw, plot_style=False, **kwargs)
    # Calculate reference dist statistics if required
    if stat_blk:
        if not isinstance(reference_data, bool) and (reference_data is not None):
            refavg = gs.weighted_mean(reference_data, reference_weight)
            refstd = np.sqrt(gs.weighted_variance(reference_data, reference_weight))
            ndat = len(reference_data)
        else:
            refavg = np.nan
            refstd = np.nan
            ndat = np.nan
    # Configure plot
    if xlabel is None:
        xlabel = gs.get_label(reference_data)
        if xlabel is None:
            xlabel = gs.get_label(simulated_data)
    # axis_xy and grid are applied by format_plot based on the current Parameters setting
    # no kwarg in this function for now since it's already loaded
    ax = format_plot(ax, xlabel, 'Cumulative Frequency', title, xlim=xlim)
    # Ensure that we have top spline, in case it was removed above
    ax.spines['top'].set_visible(True)
    # Plot statistics block
    if stat_blk:
        statlist = {'nreal': (r'$n_{real} = %0.0f$' % nreal),
                    'realavg': (r'$m_{real} = %0.3f$' % realavg),
                    'realavgstd': (r'$\sigma_{m_{real}} = %0.3f$' % realavgstd),
                    'realstd': (r'$\sigma_{real} = %0.3f$' % realstd),
                    'realstdstd': (r'$\sigma_{\sigma_{real}} = %0.3f$' % realstdstd),
                    'ndat': ('$n_{ref} = %0.0f$' % ndat),
                    'refavg': (r'$m_{ref} = %0.3f$' % refavg),
                    'refstd': (r'$\sigma_{ref} = %0.3f$' % refstd)}
        statsets = {'all': ['nreal', 'realavg', 'realavgstd', 'realstd', 'realstdstd', 'ndat',
                            'refavg', 'refstd'],
                    'minimal': ['nreal']}
        if subsamp:
            statlist['n_subsample'] = '$n_{subsample} = %0.0f$' % n_subsample
            statsets['all'].append('n_subsample')
        txtstats, stat_xy, ha, va = gs.get_statblk(stat_blk, statsets, statlist, stat_xy)
        stat_fontsize = _set_stat_fontsize(None)
        ax.text(stat_xy[0], stat_xy[1], txtstats, va=va, ha=ha, fontsize=stat_fontsize,
                transform=ax.transAxes)
    # Export figure
    if output_file or ('pdfpages' in out_kws):
        gs.export_image(output_file, **out_kws)

    return ax
Exemplo n.º 4
0
def scatter_plot(x, y, wt=None, nmax=None, s=None, c=None, alpha=None, cmap=None, clim=None, cbar=False,
                cbar_label=None, stat_blk=None, stat_xy=None, stat_ha=None, stat_fontsize=None,
                roundstats=None, sigfigs=None, xlim=None, ylim=None, xlabel=None, ylabel=None, output_file=None, out_kws = None,
                title=None, grid=None, axis_xy=None, label='_nolegend_', ax=None, figsize=None,
                return_plot=False, logx=None, logy=None, **kwargs):
    '''
    Scatter plot that mimics the GSLIB scatter_plot program, providing summary statistics, kernel
    density estimate coloring, etc. Records where x, y or the weight is NaN or infinite are
    treated as null and removed from both the plot and the statistics.

    Parameters:
        x(np.ndarray or pd.Series): 1-D array with the variable to plot on the x-axis.
        y(np.ndarray or pd.Series): 1-D array with the variable to plot on the y-axis.

    Keyword arguments:
        wt(np.ndarray or pd.Series): 1-D array with weights that are used in the calculation of
            displayed statistics. Must have the same shape as x. Unit weights are used if None.
        nmax (int): specify the maximum number of scatter points that should be displayed, which
            may be necessary due to the time-requirements of plotting many data. If specified and
            the data has more than nmax points, nmax random indices are drawn (with
            np.random.randint) and only those points are plotted. Note that this does not impact
            summary statistics.
        s(float or np.ndarray or pd.Series): size of each scatter point. An array with the same
            length as x is filtered and sub-sampled together with x and y. Based on
            Parameters['plotting.scatter_plot.s'] if None.
        c(color or np.ndarray or pd.Series): color of each scatter point, as an array or valid
            Matplotlib color. Alternatively, 'KDE' may be specified to color each point according
            to its associated kernel density estimate; an optional percentile suffix between 1
            and 100 (e.g., 'KDE95' or 'KDEp95') clips the largest density values. Based on
            Parameters['plotting.scatter_plot.c'] if None.
        alpha(float): opacity of the scatter. Based on Parameters['plotting.scatter_plot.alpha'] if None.
        cmap (str): A matplotlib colormap object or a registered matplotlib colormap name. Only
            used when c provides one value per plotted point. Based on
            Parameters['plotting.scatter_plot.cmap'] if None.
        clim (float tuple): Data minimum and maximum values
        cbar (bool): Indicate if a colorbar should be plotted or not. Ignored when c is a single
            color rather than one value per plotted point.
        cbar_label (str): Colorbar title
        stat_blk(str or list): statistics to place in the plot, which should be 'all' or
            a list that may contain ['count', 'pearson', 'spearmanr', 'noweightflag']. 'all'
            selects every statistic except 'noweightflag', which suppresses the 'weights used'
            note. Based on Parameters['plotting.scatter_plot.stat_blk'] if None. Set to False to
            disable.
        stat_xy (float tuple): X, Y coordinates of the annotated statistics in axes-fraction
            space. Based on Parameters['plotting.scatter_plot.stat_xy'] if None.
        stat_ha (str): Horizontal alignment parameter for the annotated statistics. Can be
            ``'right'``, ``'left'``, or ``'center'``. If None, based on
            Parameters['plotting.stat_ha']
        stat_fontsize (float): the fontsize for the statistics block. If None, based on
            Parameters['plotting.stat_fontsize']. If less than 1, it is the fraction of the
            matplotlib.rcParams['font.size']. If greater than 1, it the absolute font size.
        roundstats (bool): Indicate if the statistics should be rounded to the number of digits or
            to a number of significant figures (e.g., 0.000 vs. 1.14e-5). The number of digits or
            figures used is set by the parameter ``sigfigs``. Based on
            Parameters['plotting.roundstats'] if None.
        sigfigs (int): Number of significant figures or number of digits (depending on
            ``roundstats``) to display for the float statistics. Based on
            Parameters['plotting.sigfigs'] if None.
        xlim(tuple): x-axis limits - xlim[0] to xlim[1]. Based on the data if None
        ylim(tuple): y-axis limits - ylim[0] to ylim[1]. Based on the data if None.
        xlabel(str): label of the x-axis, extracted from x if None
        ylabel(str): label of the y-axis, extracted from y if None
        output_file (str): Output figure file name and location
        out_kws (dict): Optional dictionary of permissible keyword arguments to pass to
            :func:`gs.export_image() <pygeostat.plotting.export_image.export_image>`
        title(str): plot title
        grid(bool): plot grid lines in each panel? Based on Parameters['plotting.grid'] if None.
        axis_xy(bool): if True, mimic a GSLIB-style scatter_plot, where only the bottom and left axes
            lines are displayed. Based on Parameters['plotting.axis_xy'] if None.
        label(str): label of scatter for legend
        ax(Matplotlib axis handle): if None, create a new figure and axis handles
        figsize(tuple): size of the figure, if creating a new one when ax = None
        return_plot(bool): if True, also return the object returned by matplotlib's scatter
        logx, logy (str): permissible mpl axis scale, like `log`
        **kwargs: Optional permissible keyword arguments to pass to matplotlib's scatter function

    Return:
        ax(Matplotlib axis handle), or the tuple (ax, plot) if return_plot is True

    Raises:
        ValueError: if x is not 1-D, if x, y and wt differ in shape, if a KDE percentile cannot
            be interpreted, if a negative lower limit is used with a log axis, or if stat_blk
            is invalid.

    **Examples:**

    Basic scatter example:

    .. plot::

        import pygeostat as gs

        # Load the data
        data_file = gs.ExampleData('point3d_ind_mv')

        # Select a couple of variables
        x, y = data_file[data_file.variables[0]], data_file[data_file.variables[1]]

        # Scatter plot with default parameters
        gs.scatter_plot(x, y, figsize=(5, 5), cmap='hot')

        # Scatter plot without correlation and with a color bar:
        gs.scatter_plot(x, y, nmax=2000, stat_blk=False, cbar=True, figsize=(5, 5))

        # Scatter plot with the a constant color, transparency and all statistics
        # Also locate the statistics where they are better seen
        gs.scatter_plot(x, y, c='k', alpha=0.2, nmax=2000, stat_blk='all', stat_xy=(.95, .95),
                   figsize=(5, 5))
    '''
    # Import packages
    from scipy.stats import gaussian_kde
    import pygeostat as gs
    from .utils import _set_stat_fontsize
    # Figure out the plotting axes
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    # Labels must be extracted before the pandas containers are reduced to bare arrays
    if xlabel is None:
        xlabel = gs.get_label(x)
    if ylabel is None:
        ylabel = gs.get_label(y)
    # Check the input data
    if isinstance(x, (pd.DataFrame, pd.Series)):
        x = x.values
    if x.ndim > 1:
        raise ValueError('x should be one-dimension!')
    if isinstance(y, (pd.DataFrame, pd.Series)):
        y = y.values
    if y.shape != x.shape:
        raise ValueError('x and y should be the same shape!')
    # Check the weights
    if isinstance(wt, (pd.DataFrame, pd.Series)):
        wt = wt.values
    elif wt is None:
        wt = np.ones(x.shape)
    if wt.shape != x.shape:
        raise ValueError('x, y and wt should be the same shape!')
    # Number of records before filtering; used to recognise per-point c/s arrays
    ndata = x.shape[0]
    # Remove records with a non-finite x, y or weight. The masks are combined with `&` because
    # the third positional argument of np.logical_and is the output buffer, not a third operand.
    idx = np.isfinite(x) & np.isfinite(y) & np.isfinite(wt)
    x, y, wt = x[idx], y[idx], wt[idx]
    # Draw a random sub-sample for plotting if requested. idx1 always exists so that per-point
    # color/size arrays can be indexed consistently whether or not sub-sampling takes place.
    xplot, yplot = x, y
    idx1 = np.arange(len(x))
    if isinstance(nmax, int) and len(x) > nmax:
        idx1 = np.random.randint(0, len(x), nmax)
        xplot, yplot = x[idx1], y[idx1]
    # There's probably a lot of edge cases to this testing that are not yet
    # handled
    if isinstance(c, (pd.DataFrame, pd.Series)):
        if cbar_label is None:
            cbar_label = gs.get_label(c)
        c = c.values
    # Per-point colors follow the same filtering and sub-sampling as x and y
    if isinstance(c, np.ndarray) and c.shape[:1] == (ndata,):
        c = c[idx][idx1]
    # Per-point sizes follow the same filtering and sub-sampling as x and y
    if isinstance(s, pd.Series):
        s = s.values
    if isinstance(s, np.ndarray) and s.shape == (ndata,):
        s = s[idx][idx1]
    # Calculate kernel density estimate at data locations if necessary
    if c is None:
        c = Parameters['plotting.scatter_plot.c']
    kde = False
    if isinstance(c, str):
        if c.lower()[:3] == 'kde':
            pval = c.lower()[3:]
            # Points are colored based on KDE, evaluated in log space for log-scaled axes
            if logy:
                ykde = yplot.copy()
                ykde[ykde <= 0] = Parameters['plotting.log_lowerval']
                ykde = np.log(ykde)
            else:
                ykde = yplot
            if logx:
                xkde = xplot.copy()
                xkde[xkde <= 0] = Parameters['plotting.log_lowerval']
                xkde = np.log(xkde)
            else:
                xkde = xplot
            xy = np.stack((xkde, ykde), axis=1)
            kde = gaussian_kde(xy.T)
            c = kde.evaluate(xy.T)
            # Rescale the density to [0, 1]
            cmin, cmax = c.min(), c.max()
            c = (c - cmin) / (cmax - cmin)
            if len(pval) > 0:
                try:
                    ipval = int(pval[1:] if pval.startswith('p') else pval)
                except ValueError:
                    raise ValueError('Could not interpret {} as a kde percentile!'.
                                     format(pval))
                # Explicit check rather than assert, which is stripped when running with -O
                if not 1 <= ipval <= 100:
                    raise ValueError('kde percentiles must be 1 <= p <= 100 ')
                ipval -= 1
                cdfx, cdfy = gs.cdf(c, bins=101)
                clipval = np.interp(ipval / 100, cdfy, cdfx)
                c[c > clipval] = clipval
            kde = True
        else:
            cbar = False
    # Draw parameters from Parameters if necessary
    if s is None:
        s = Parameters['plotting.scatter_plot.s']
    if alpha is None:
        alpha = Parameters['plotting.scatter_plot.alpha']
    if stat_blk is None:
        stat_blk = Parameters['plotting.scatter_plot.stat_blk']
    if roundstats is None:
        roundstats = Parameters['plotting.roundstats']
    if sigfigs is None:
        sigfigs = Parameters['plotting.sigfigs']
    # Set-up some parameters. A colormap (and colorbar) only makes sense when c holds one
    # value per plotted point; a color string, None or a scalar is never treated that way.
    color_is_mapped = np.ndim(c) >= 1 and len(c) == xplot.shape[0]
    if not color_is_mapped:
        cmap = False
        # No tick locations/labels exist for a colorbar in this case
        cbar = False
    elif cmap is None:
        cmap = Parameters['plotting.scatter_plot.cmap']
    if cmap is not False:
        clim, ticklocs, ticklabels = gs.get_contcbarargs(c, sigfigs, clim)
    if clim is None:
        clim = (None, None)
    # Set-up plot if no axis is supplied using the ImageGrid method if required or the regular way
    cax = None
    fig, ax, cax = gs.setup_plot(ax, cax=cax, cbar=cbar, figsize=figsize)
    # Scatter - let Matplotlib use the default size/color if None
    scatter_kws = dict(alpha=alpha, label=label, cmap=cmap, vmin=clim[0], vmax=clim[1])
    if s is not None:
        scatter_kws['s'] = s
    if c is not None:
        scatter_kws['c'] = c
    plot = ax.scatter(xplot, yplot, **scatter_kws, **kwargs)
    # Setup the colorbar if required
    if cbar:
        if kde:
            if clim[0] is not None and clim[1] is not None:
                ticklocs = np.linspace(clim[0], clim[1], 3)
            else:
                ticklocs = [0, 0.5, 1]
            ticklabels = ['Low', 'Med.', 'High']
            cbar_label = 'Kernel Density Estimate'
        colorbar = fig.colorbar(plot, cax=cax, ticks=ticklocs)
        # Configure the color bar
        colorbar.ax.set_yticklabels(ticklabels, ha='left')
        colorbar.ax.tick_params(axis='y', pad=2)
        if cbar_label is not None:
            colorbar.set_label(cbar_label, ha='center', va='top', labelpad=2)
    # Set the axis extents
    if xlim is None:
        xlim = (np.min(x), np.max(x))
    if ylim is None:
        ylim = (np.min(y), np.max(y))
    # A log axis cannot start at or below zero: bump a zero lower limit, reject a negative one
    if logx and xlim[0] <= 0:
        if xlim[0] == 0:
            xlim = [Parameters['plotting.log_lowerval'], xlim[1]]
        else:
            raise ValueError('ERROR: invalid xlim for a log x-axis!')
    if logy and ylim[0] <= 0:
        if ylim[0] == 0:
            ylim = [Parameters['plotting.log_lowerval'], ylim[1]]
        else:
            raise ValueError('ERROR: invalid ylim for a log y-axis!')
    # Set the formatting attributes
    gs.format_plot(ax, xlabel, ylabel, title, grid, axis_xy, xlim, ylim, logx, logy)
    # Setup the correlation
    if stat_blk:
        stats = ['pearson', 'spearmanr', 'count', 'noweightflag']
        # Error checking and conversion to a list of stats
        if isinstance(stat_blk, str):
            if stat_blk == 'all':
                stat_blk = stats[:-1]
            else:
                stat_blk = [stat_blk]
        elif isinstance(stat_blk, tuple):
            stat_blk = list(stat_blk)
        if not isinstance(stat_blk, list):
            raise ValueError('invalid stat_blk')
        if any(stat not in stats for stat in stat_blk):
            raise ValueError('invalid stat_blk')

        def _format_corr(corr):
            # Round to a number of digits or to significant figures, depending on roundstats
            if roundstats:
                return str(round(corr, sigfigs))
            return str(gs.round_sigfig(corr, sigfigs))

        # Build the txtstats line by line so that no stray leading newline is produced when
        # 'count' is not requested
        statlines = []
        if 'count' in stat_blk:
            statlines.append(r'$n = $' + str(x.shape[0]))
        if 'pearson' in stat_blk:
            statlines.append(r'$\rho = $' + _format_corr(gs.weighted_correlation(x, y, wt)))
        if 'spearmanr' in stat_blk:
            statlines.append(r'$\rho_s = $'
                             + _format_corr(gs.weighted_correlation_rank(x, y, wt)))
        txtstats = '\n'.join(statlines)
        # Note if weights were used
        if len(np.unique(wt)) > 1 and 'noweightflag' not in stat_blk:
            txtstats = txtstats + '\n\nweights used'
        # Sort the location and font size
        if stat_xy is None:
            stat_xy = Parameters['plotting.scatter_plot.stat_xy']
        if stat_ha is None:
            stat_ha = Parameters['plotting.stat_ha']
        # Anchor the text so that it grows towards the middle of the axes
        va = 'top' if stat_xy[1] > 0.5 else 'bottom'
        stat_fontsize = _set_stat_fontsize(stat_fontsize)
        # Draw to plot
        ax.text(stat_xy[0], stat_xy[1], txtstats, va=va, ha=stat_ha, transform=ax.transAxes,
                fontsize=stat_fontsize, linespacing=0.8)

    # Handle dictionary defaults
    if out_kws is None:
        out_kws = dict()

    if output_file or ('pdfpages' in out_kws):
        gs.export_image(output_file, **out_kws)

    if return_plot:
        return ax, plot
    return ax
Exemplo n.º 5
0
def histogram_plot(data,
                   var=None,
                   weights=None,
                   cat=None,
                   catdict=None,
                   bins=None,
                   icdf=False,
                   lower=None,
                   upper=None,
                   ax=None,
                   figsize=None,
                   xlim=None,
                   ylim=None,
                   title=None,
                   xlabel=None,
                   stat_blk=None,
                   stat_xy=None,
                   stat_ha=None,
                   roundstats=None,
                   sigfigs=None,
                   color=None,
                   edgecolor=None,
                   edgeweights=None,
                   grid=None,
                   axis_xy=None,
                   label_count=False,
                   rotateticks=None,
                   plot_style=None,
                   custom_style=None,
                   output_file=None,
                   out_kws=None,
                   stat_fontsize=None,
                   stat_linespacing=None,
                   logx=False,
                   **kwargs):
    """
    Generates a matplotlib style histogram with summary statistics. Trimming is now only applied
    to NaN values (Pygeostat null standard).

    The only required parameter is ``data``. If ``xlabel`` is left to its default value of
    ``None`` and the input data is contained in a pandas dataframe or series, the column
    information will be used to label the x-axis.

    Named statistics block sets are available: ``'minimal'``, ``'all'``, ``'varlabel'`` (``'all'``
    preceded by the ``label`` keyword argument, when one is passed) and ``'none'``. The
    statistics block can be customized to a user defined list and order. Available statistics are
    as follows:

    >>> ['count', 'mean', 'stdev', 'cvar', 'max', 'upquart', 'median', 'lowquart', 'min',
    ...  'p10', 'p90']

    The way in which the values within the statistics block are rounded and displayed can be
    controlled using the parameters ``roundstats`` and ``sigfigs``.

    Please review the documentation of the :func:`gs.set_style()
    <pygeostat.plotting.set_style.set_style>` and :func:`gs.export_image()
    <pygeostat.plotting.export_image.export_image>` functions for details on their parameters so that
    their use in this function can be understood.

    Parameters:
        data (np.ndarray, pd.DataFrame/Series, or gs.DataFile): data array, which must be 1D
            unless var is provided. The exception being a DataFile, if data.variables
            is a single name.
        var (str): name of the variable in data, which is required if data is not 1D.
        weights (np.ndarray, pd.DataFrame/Series, or gs.DataFile or str): 1D array of declustering
             weights for the data. Alternatively the declustering weights name in var. If data
             is a DataFile, it may be string in data.columns, or True to use data.weights
             (if data.weights is not None). Weights are normalized to sum to one and must not
             be negative.
        cat (bool or str): either a cat column in data.data, or if True uses data.cat if data.cat
            is not None
        catdict (dict or bool): overrides bins. If a categorical variable is being plotted, provide
            a dictionary where keys are numeric (categorical codes) and values are their associated
            labels (categorical names). The bins will be set so that the left edge (and associated
            label) of each bar is inclusive to each category. May also be set to True, if data is
            a DataFile and data.catdict is initialized. A dictionary passed here takes precedence
            over data.catdict.
        bins (int or list): Number of bins to use, or a list of bins
        icdf (bool): Indicator to plot a CDF or not
        lower (float): Lower limit for histogram
        upper (float): Upper limit for histogram
        ax (mpl.axis): Matplotlib axis to plot the figure
        figsize (tuple): Figure size (width, height)
        xlim (float tuple): Minimum and maximum limits of data along the x axis
        ylim (float tuple): Minimum and maximum limits of data along the y axis
        title (str): Title for the plot
        xlabel (str): X-axis label
        stat_blk (bool, str or list): Indicate if statistics are plotted or not. ``True`` selects
            the ``'all'`` set, a string selects one of the named sets, and a list selects
            individual statistics. Based on Parameters['plotting.histogram_plot.stat_blk'] if
            None. The statistics are not drawn when categories are plotted for a ``var``.
        stat_xy (float tuple): X, Y coordinates of the annotated statistics in figure
            space. Based on Parameters['plotting.histogram_plot.stat_xy'] if a histogram and
            Parameters['plotting.histogram_plot.stat_xy_cdf'] if a CDF, which defaults to the top
            right when a PDF is plotted and the bottom right if a CDF is plotted.
        stat_ha (str): Horizontal alignment parameter for the annotated statistics. Can be
            ``'right'``, ``'left'``, or ``'center'``. If None, based on
            Parameters['plotting.stat_ha']
        stat_fontsize (float): the fontsize for the statistics block. If None, based on
            Parameters['plotting.stat_fontsize']. If less than 1, it is the fraction of the
            matplotlib.rcParams['font.size']. If greater than 1, it the absolute font size.
        stat_linespacing (float): line spacing of the statistics block. If None, based on
            Parameters['plotting.stat_linespacing'], falling back to 1.0 if that is also None.
        roundstats (bool): Indicate if the statistics should be rounded to the number of digits or
            to a number of significant figures (e.g., 0.000 vs. 1.14e-5). The number of digits or
            figures used is set by the parameter ``sigfigs``.
        sigfigs (int): Number of significant figures or number of digits (depending on
            ``roundstats``) to display for the float statistics
        color (str or int or dict): Any permissible matplotlib color or a integer which is used to
            draw a color from the pygeostat pastel color pallet. May also be a dictionary of
            colors, which are used for each bar (useful for categories). colors.keys() must align
            with catdict.keys() if a dictionary is passed. Drawn from
            Parameters['plotting.cmap_cat'] if catdict is used and their keys align.
        edgecolor (str): Any permissible matplotlib color for the edge of a histogram bar
        edgeweights (float): line width of the histogram bar edges. If None, an ``lw`` keyword
            argument is used if passed, otherwise
            Parameters['plotting.histogram_plot.edgeweight']. Only used when a histogram is
            plotted.
        grid(bool): plots the major grid lines if True. Based on Parameters['plotting.grid']
            if None.
        axis_xy (bool): converts the axis to GSLIB-style axis visibility (only left and bottom
            visible) if axis_xy is True. Based on Parameters['plotting.axis_xy'] if None.
        label_count (bool): label the number of samples found for each category in catdict. Does
            nothing if no catdict is found
        rotateticks (bool tuple): Indicate if the axis tick labels should be rotated (x, y)
        plot_style (str): Use a predefined set of matplotlib plotting parameters as specified by
            :class:`gs.GridDef <pygeostat.data.grid_definition.GridDef>`. Use ``False`` or ``None``
            to turn it off
        custom_style (dict): Alter some of the predefined parameters in the ``plot_style`` selected.
        output_file (str): Output figure file name and location
        out_kws (dict): Optional dictionary of permissible keyword arguments to pass to
            :func:`gs.export_image() <pygeostat.plotting.export_image.export_image>`
        logx (bool): if True, the histogram bins are logarithmically spaced between the data (or
            ``xlim``) limits and ``logx`` is forwarded to the axis formatting. Cannot be combined
            with ``catdict`` when a histogram is plotted.
        **kwargs: Optional permissible keyword arguments to pass to either: (1) matplotlib's hist
            function if a PDF is plotted or (2) matplotlib's plot function if a CDF is plotted.

    Returns:
        ax (ax): matplotlib Axes object with the histogram

    Raises:
        ValueError: if ``weights``, ``cat``, ``catdict`` or ``var`` cannot be resolved against
            ``data``, if ``data`` cannot be coerced to 1D, if any weight is negative, if a
            ``color`` dictionary does not align with ``catdict``, or if ``logx`` is used with
            ``catdict`` or with data that cannot be log-transformed.

    **Examples:**

    A simple call:

    .. plot::

        import pygeostat as gs
        # load some data
        dfl = gs.ExampleData("point3d_ind_mv")
        # plot the histogram_plot
        gs.histogram_plot(dfl, var="Phi", bins=30)

    |

    Change the colour, number of significant figures displayed in the statistics, and pass some
    keyword arguments to matplotlibs hist function:

    .. plot::

        import pygeostat as gs
        # load some data
        dfl = gs.ExampleData("point3d_ind_mv")
        # plot the histogram_plot
        gs.histogram_plot(dfl, var="Phi", color='#c2e1e5', sigfigs=5, log=True, density=True)

    |

    Plot a CDF while also displaying all available statistics, which have been shifted up:

    .. plot::

        import pygeostat as gs
        # load some data
        dfl = gs.ExampleData("point3d_ind_mv")
        # plot the histogram_plot
        gs.histogram_plot(dfl, var="Phi", icdf=True, stat_blk='all', stat_xy=(1, 0.75))
        # Change the CDF line colour by passing an integer, which draws a colour from the
        # pygeostat pastel colour pallet, and increase its width by passing a keyword argument
        # to matplotlib's plot function. Also define a custom statistics block:
        gs.histogram_plot(dfl, var="Phi", icdf=True, color=3, lw=3.5, stat_blk=['count','upquart'])

    |

    Generate histograms of Phi considering the categories:

    .. plot::

        import pygeostat as gs
        # load some data
        dfl = gs.ExampleData("point3d_ind_mv")
        cats = [1, 2, 3, 4, 5]
        colors = gs.catcmapfromcontinuous("Spectral", 5).colors
        # build the required cat dictionaries
        dfl.catdict = {c: "RT {:02d}".format(c) for c in cats}
        colordict =  {c: colors[i] for i, c in enumerate(cats)}
        # plot the histogram_plot
        f, axs = plt.subplots(2, 1, figsize=(8, 6))
        for var, ax in zip(["Phi", "Sw"], axs):
            gs.histogram_plot(dfl, var=var, cat=True, color=colordict, bins=40, figsize=(8, 4), ax=ax,
                       xlabel=False, title=var)

    |

    Generate cdf subplots considering the categories:

    .. plot::

        import pygeostat as gs
        # load some data
        dfl = gs.ExampleData("point3d_ind_mv")
        cats = [1, 2, 3, 4, 5]
        colors = gs.catcmapfromcontinuous("Spectral", 5).colors
        # build the required cat dictionaries
        dfl.catdict = {c: "RT {:02d}".format(c) for c in cats}
        colordict =  {c: colors[i] for i, c in enumerate(cats)}
        # plot the histogram_plot
        f, axs = plt.subplots(2, 2, figsize=(12, 9))
        axs=axs.flatten()
        for var, ax in zip(dfl.variables, axs):
            gs.histogram_plot(dfl, var=var, icdf=True, cat=True, color=colordict, ax=ax)

    Recreate the `Proportion` class plot

    .. plot::

        import pygeostat as gs
        # load some data
        dfl = gs.ExampleData("point3d_ind_mv")
        cats = [1, 2, 3, 4, 5]
        colors = gs.catcmapfromcontinuous("Spectral", 5).colors
        # build the required cat dictionaries
        dfl.catdict = {c: "RT {:02d}".format(c) for c in cats}
        colordict =  {c: colors[i] for i, c in enumerate(cats)}
        # plot the histogram_plot
        ax = gs.histogram_plot(dfl, cat=True, color=colordict, figsize=(7, 4), rotateticks=(45, 0),
                        label_count=True)

    """
    import pygeostat as gs
    from .utils import (format_plot, _set_stat_fontsize, _format_grid, _format_tick_labels,
                        setup_plot, catcmapfromcontinuous)
    from .cmaps import _cat_pastel_data
    import copy
    # Now converting to a numpy array, as encountering some odd pandas performance, and there's
    # no major disadvantage to application of a numpy in this context to my knowledge - RMB
    # If a list is passed convert it to a series so that trimming can take place
    # weights
    if isinstance(weights, str):
        if isinstance(data, pd.DataFrame) or isinstance(data, gs.DataFile):
            weights = data[weights]
    elif isinstance(weights, bool):
        if weights:
            if isinstance(data, gs.DataFile):
                if data.weights is None:
                    raise ValueError('weights=True but data.weights is None!')
                elif isinstance(data.weights, list):
                    raise ValueError(
                        'weights=True but data.weights is a list!')
                weights = data[data.weights].values
            else:
                raise ValueError(
                    'weights=True is only valid if data is a DataFile!')
        else:
            weights = None
    if isinstance(weights, pd.Series) or isinstance(weights, pd.DataFrame):
        weights = weights.values
    # cats for continuous histogram_plots
    if isinstance(cat, str):
        if isinstance(data, pd.DataFrame) or isinstance(data, gs.DataFile):
            cat = data[cat]
    elif isinstance(cat, bool):
        if cat:
            if isinstance(data, gs.DataFile):
                if data.cat is None:
                    raise ValueError('cat=True but data.cat is None!')
                cat = data[data.cat].values
                # A dictionary passed by the caller takes precedence; only fall back to the
                # DataFile's catdict when no dictionary was supplied.
                if not isinstance(catdict, dict):
                    if data.catdict is None:
                        raise ValueError("pass a `catdict` when setting `cat`")
                    catdict = data.catdict
            else:
                raise ValueError(
                    'cat=True is only valid if data is a DataFile!')
        else:
            cat = None
    if isinstance(cat, pd.Series) or isinstance(cat, pd.DataFrame):
        cat = cat.values
    # Handle categorical dictionary
    if isinstance(catdict, bool):
        if catdict:
            if not isinstance(data, gs.DataFile):
                raise ValueError(
                    'catdict as a bool is only valid if data is a DataFile!')
            if data.catdict is None:
                raise ValueError(
                    'catdict as a bool is only valid if data is not None!')
            catdict = data.catdict
        else:
            # catdict=False means "no categories", consistent with weights/cat handling
            catdict = None
    # Variable
    # Handle data that is 2-D and/or a DataFile
    if isinstance(var, str):
        if isinstance(data, pd.DataFrame) or isinstance(data, gs.DataFile):
            if isinstance(cat, str):
                cat = data[cat]
            data = data[var]
        else:
            raise ValueError(
                'var as a string is only valid if data is a DataFile or DataFrame!'
            )
    elif isinstance(data, gs.DataFile):
        if isinstance(data.variables, str):
            data = data[data.variables]
        elif cat is not None:
            if isinstance(cat, str):
                data = data[cat]
            elif var is None and isinstance(cat, (np.ndarray, list)):
                # No variable requested: the categories themselves are plotted
                data = cat
        elif len(data.columns) == 1:
            data = data.data
        else:
            raise ValueError(
                'Could not coerce data (DataFile) into a 1D dataset!')
    # Get the xlabel if possible before converting to a numpy array
    if isinstance(data, pd.Series) or isinstance(data, pd.DataFrame):
        if xlabel is None:
            xlabel = gs.get_label(data)
        data = data.values
    elif isinstance(data, list):
        data = np.array(data)
    if isinstance(cat, (pd.Series, pd.DataFrame)):
        cat = cat.values
    # Should be numpy by now...
    if data.ndim > 1:
        if data.shape[1] > 1:
            raise ValueError('Could not coerce data into a 1D dataset!')
        else:
            data = data.flatten()
    # Handle Null values if needed
    idx = np.isnan(data)
    nullcnt = np.sum(idx)
    if nullcnt > 0:
        data = data[~idx]
        if weights is not None:
            weights = weights[~idx]
        if cat is not None:
            cat = cat[~idx]
    # Handle dictionary defaults
    if out_kws is None:
        out_kws = dict()
    # Set-up plot if no axis is supplied
    _, ax, _ = setup_plot(ax, figsize=figsize, aspect=False)
    # Infer some default parameters
    if weights is None:
        weights = np.ones(len(data)) / len(data)
    else:
        weights = np.asarray(weights)
        # Validate the raw weights element-wise before normalizing, since dividing by a
        # negative sum would flip their signs.
        if np.any(weights < 0.0):
            raise ValueError('weights less than 0 not valid')
        weights = weights / np.sum(weights)
    # Categories
    if isinstance(catdict, dict) and var is None:
        if not all([isinstance(float(i), float) for i in catdict.keys()]):
            raise ValueError(
                'if catdict is dict., all keys should be an int/float!')
        # The bins are set to begin at the start of each category
        # bins go from 0.5 to (icat + 1) + 0.5
        # label is centered at (icat + 1)
        bins = np.arange(len(catdict) + 1) + 0.5
    if color is None and isinstance(catdict, dict):
        # Color each bin by the category color?
        if isinstance(Parameters['plotting.cmap_cat'], dict):
            temp = Parameters['plotting.cmap_cat']
            if list(sorted(temp.keys())) == list(sorted(catdict.keys())):
                color = temp
        else:
            color = catcmapfromcontinuous(Parameters["plotting.cmap"],
                                          len(catdict)).colors
    if isinstance(color, dict):
        # A color dictionary is only meaningful alongside a matching catdict
        if (not isinstance(catdict, dict)
                or list(sorted(color.keys())) != list(sorted(catdict.keys()))):
            raise ValueError(('if color is a dictionary, keys must align with '
                              'bins[:-1]! Consider using a single color.'))
        # Flatten to a list of colors ordered by category code
        color = [v for _, v in sorted(color.items())]
    # Color setup
    if isinstance(color, int):
        # Grab a color from the pastel categorical pallet if an integer is passed, wrapping
        # around the pallet's own length
        color = _cat_pastel_data[color % len(_cat_pastel_data)]
    if not icdf:
        if color is None:
            color = Parameters['plotting.histogram_plot.facecolor']
        if edgecolor is None:
            edgecolor = Parameters['plotting.histogram_plot.edgecolor']
        if edgeweights is None:
            if "lw" in kwargs:
                edgeweights = kwargs.pop("lw")
            else:
                edgeweights = Parameters["plotting.histogram_plot.edgeweight"]
    else:
        if color is None and icdf:
            color = Parameters['plotting.histogram_plot.cdfcolor']
    plotdata = copy.deepcopy(data)
    plotweights = copy.deepcopy(weights)
    if xlim is not None:
        # Clip the plotted values so out-of-range data lands in the end bins
        plotdata[data < xlim[0]] = xlim[0]
        plotdata[data > xlim[1]] = xlim[1]
    # Series label of the histogram; remembered so the 'varlabel' statistics set can use it
    label = None
    # Main plotting
    if icdf:

        def singlecdf(ax,
                      data,
                      weights,
                      lower,
                      upper,
                      bins,
                      color,
                      label=None,
                      **kwargs):
            """ local function to plot a single cdf """
            cdf_x, cdfvals = gs.cdf(data,
                                    weights=weights,
                                    lower=lower,
                                    upper=upper,
                                    bins=bins)
            # Matplotlib is a memory hog if to many points are used. Limit the number of points the CDF
            # is build with to 1000. The tails are given extra attention to make sure they are defined
            # nicely.
            if len(cdf_x) > 1000:
                cdfinterp = scipy.interpolate.interp1d(x=cdfvals, y=cdf_x)
                cdfvals = np.concatenate([
                    np.arange(cdfvals.min(), 0.1, 0.001),
                    np.arange(0.1, 0.9, 0.01),
                    np.arange(0.9, cdfvals.max(), 0.001)
                ])
                cdf_x = []
                for val in cdfvals:
                    cdf_x.append(cdfinterp(val))
                cdf_x = np.array(cdf_x)
            fig = ax.plot(cdf_x, cdfvals, color=color, label=label, **kwargs)
            return fig

        if catdict is not None:
            if var is not None:
                # One CDF per category; a single statistics block would be ambiguous
                stat_blk = False
                for icat, c in enumerate(catdict):
                    clr = color[icat]
                    catidx = cat == c
                    fig = singlecdf(ax,
                                    plotdata[catidx],
                                    plotweights[catidx],
                                    lower,
                                    upper,
                                    bins,
                                    clr,
                                    label=catdict[c],
                                    **kwargs)
            else:
                raise ValueError(
                    "`icdf=True` and `catdict` only makes sense with a `var` defined"
                )
        else:
            fig = singlecdf(ax, plotdata, plotweights, lower, upper, bins,
                            color, **kwargs)
        if ylim is None:
            ylim = (0, 1.0)
    else:
        if bins is None:
            bins = Parameters['plotting.histogram_plot.histbins']
        label = kwargs.pop("label", None)
        if bins is None:
            if len(plotdata) < 200:
                bins = 20
            elif len(plotdata) < 500:
                bins = 25
            else:
                bins = 30
        if logx:
            if catdict is not None:
                raise ValueError('Cannot have logx with catdict!')
            if xlim is None:
                minv = np.log10(max(plotdata.min(), 1e-10))
                maxv = np.log10(plotdata.max())
            else:
                minv = np.log10(max(xlim[0], 1e-10))
                maxv = np.log10(xlim[1])
            if np.isnan([minv, maxv]).any():
                raise ValueError(
                    'ERROR converting your data to log base! are there negatives?'
                )
            bins = np.logspace(minv, maxv, bins)
        # Defaults that the categorical branches below may override. histclr must always be
        # bound because the list-of-colors hist call reads it even without a catdict.
        histclr = None
        default_histtype = "bar"
        if catdict is not None:
            if var is None:
                # Map each category code to its 1-based position so bars sit on 1..ncat
                for icat, code in enumerate(catdict):
                    plotdata[data == code] = icat + 1
            else:
                # generate lists of data per cat
                plotdata = [plotdata[cat == c] for c in catdict]
                plotweights = [weights[cat == c] for c in catdict]
                label = list(catdict.values())
                default_histtype = "stepfilled"
                stat_blk = False
                kwargs.setdefault("stacked", True)
                histclr = color
        # Pop exactly once so a user supplied histtype is honoured on every path
        histtype = kwargs.pop("histtype", default_histtype)
        if not isinstance(color, list):
            ax.hist(plotdata,
                    bins,
                    weights=plotweights,
                    color=color,
                    edgecolor=edgecolor,
                    histtype=histtype,
                    label=label,
                    lw=edgeweights,
                    **kwargs)
        else:
            _, _, patches = ax.hist(plotdata,
                                    bins,
                                    weights=plotweights,
                                    histtype=histtype,
                                    color=histclr,
                                    edgecolor=edgecolor,
                                    label=label,
                                    lw=edgeweights,
                                    **kwargs)
            # Best effort per-bar coloring: the returned patches are not individually
            # colorable for every histtype/stacking combination, so failures are ignored.
            try:
                for patch, clr in zip(patches, color):
                    patch.set_facecolor(clr)
            except (AttributeError, ValueError):
                pass
        if catdict is not None and label_count:
            for icat, code in enumerate(catdict):
                count = np.count_nonzero(data == code)
                # Bar height is the (weighted) proportion of the category
                pcat = (weights * (data == code).astype(float)).sum()
                ax.text(icat + 1, pcat, count, ha="center", va="bottom")
    # Summary stats
    if stat_blk is None:
        stat_blk = Parameters['plotting.histogram_plot.stat_blk']
    if stat_xy is None:
        if icdf:
            stat_xy = Parameters['plotting.histogram_plot.stat_xy_cdf']
        else:
            stat_xy = Parameters['plotting.histogram_plot.stat_xy']
    if stat_blk:
        if sigfigs is None:
            sigfigs = Parameters['plotting.sigfigs']
        if roundstats is None:
            roundstats = Parameters['plotting.roundstats']
        if stat_ha is None:
            stat_ha = Parameters['plotting.stat_ha']
        if stat_linespacing is None:
            stat_linespacing = Parameters['plotting.stat_linespacing']
        if stat_linespacing is None:
            stat_linespacing = 1.0
        # Force no bins and upper/lower for median
        cdf_x, cdfvals = gs.cdf(data, weights=weights)
        # Currently defined statistics, possible to add more quite simply
        # NOTE: mean, median and stdev are weighted; cvar, min/max and the np.percentile
        # based statistics are computed from the unweighted data.
        if np.mean(data) == 0:
            cdata = float("nan")
        elif roundstats:
            cdata = round((np.std(data) / np.mean(data)), sigfigs)
        else:
            cdata = gs.round_sigfig((np.std(data) / np.mean(data)), sigfigs)
        if roundstats:
            mean = round(gs.weighted_mean(data, weights), sigfigs)
            median = round(gs.percentile_from_cdf(cdf_x, cdfvals, 50.0),
                           sigfigs)
            stdev = round(np.sqrt(gs.weighted_variance(data, weights)),
                          sigfigs)
            minval = round(np.min(data), sigfigs)
            maxval = round(np.max(data), sigfigs)
            upquart = round(np.percentile(data, 75), sigfigs)
            lowquart = round(np.percentile(data, 25), sigfigs)
            p10 = round(np.percentile(data, 10), sigfigs)
            p90 = round(np.percentile(data, 90), sigfigs)
        else:
            mean = gs.round_sigfig(gs.weighted_mean(data, weights), sigfigs)
            median = gs.round_sigfig(
                gs.percentile_from_cdf(cdf_x, cdfvals, 50.0), sigfigs)
            stdev = gs.round_sigfig(
                np.sqrt(gs.weighted_variance(data, weights)), sigfigs)
            minval = gs.round_sigfig(np.min(data), sigfigs)
            maxval = gs.round_sigfig(np.max(data), sigfigs)
            upquart = gs.round_sigfig(np.percentile(data, 75), sigfigs)
            lowquart = gs.round_sigfig(np.percentile(data, 25), sigfigs)
            p10 = gs.round_sigfig(np.percentile(data, 10), sigfigs)
            p90 = gs.round_sigfig(np.percentile(data, 90), sigfigs)
        statistics = {
            'mean': (r'$m = %g$' % mean),
            'median': (r'$x_{{50}} = %g$' % median),
            'count': ('$n = %i$' % len(data)),
            'count_trimmed': ('$n_{trim} = %i$' % nullcnt),
            'stdev': (r'$\sigma = %g$' % stdev),
            'cvar': ('$CV = %g$' % cdata),
            'min': ('$x_{{min}} = %g$' % minval),
            'max': ('$x_{{max}} = %g$' % maxval),
            'upquart': ('$x_{{75}} = %g$' % upquart),
            'lowquart': ('$x_{{25}} = %g$' % lowquart),
            'p10': ('$x_{{10}} = %g$' % p10),
            'p90': ('$x_{{90}} = %g$' % p90)
        }
        # Default statistic sets
        # The histogram branch pops 'label' out of kwargs, so fall back to the popped value
        varlabel = kwargs.get('label', label)
        if stat_blk == 'varlabel' and varlabel is not None:
            statistics['varlabel'] = str(varlabel)
        statsets = {
            'minimal': ['count', 'mean', 'median', 'stdev'],
            'all': [
                'count', 'mean', 'stdev', 'cvar', 'max', 'upquart', 'median',
                'lowquart', 'min'
            ],
            'varlabel': [
                'varlabel', 'count', 'mean', 'stdev', 'cvar', 'max', 'upquart',
                'median', 'lowquart', 'min'
            ],
            'none':
            None
        }
        # Use a default statistic set
        if isinstance(stat_blk, bool) and stat_blk:
            stat_blk = 'all'
        if isinstance(stat_blk, str):
            if stat_blk in statsets:
                stat_blk = statsets[stat_blk]
                if stat_blk is not None:
                    # Drop entries with no value (e.g. 'varlabel' when no label was passed)
                    stat_blk = [s for s in stat_blk if s in statistics]
            else:
                print('WARNING: stats value of: "' + stat_blk +
                      '" does not exist - '
                      'default to no stats')
                stat_blk = None
        # Use a supplied statistic set, but check for bad ones
        else:
            badstats = [s for s in stat_blk if s not in statistics]
            stat_blk = [s for s in stat_blk if s in statistics]
            for badstat in badstats:
                print('WARNING: stats value of: "' + badstat +
                      '" does not exist - '
                      'It was removed from summary statistics list')
        # Form the stats string
        if stat_blk:
            # Report the trimmed count next to the count, when the count itself is displayed
            if nullcnt != 0 and 'count' in stat_blk:
                stat_blk.insert(stat_blk.index('count') + 1, 'count_trimmed')
            stat_blk = [statistics[s] for s in stat_blk]
            txtstats = '\n'.join(stat_blk)
            if len(np.unique(weights)) > 1:
                txtstats = txtstats + '\n\nweights used'
            if stat_xy[1] > 0.5:
                va = 'top'
            else:
                va = 'bottom'
            # Set the stat_fontsize
            stat_fontsize = _set_stat_fontsize(stat_fontsize)
            ax.text(stat_xy[0],
                    stat_xy[1],
                    txtstats,
                    va=va,
                    ha=stat_ha,
                    transform=ax.transAxes,
                    fontsize=stat_fontsize,
                    linespacing=stat_linespacing)
    # Label as required
    if icdf:
        ylabel = 'Cumulative Distribution Function'
    elif kwargs.get('density', False):
        # Only a truthy density request makes this a PDF; density=False is a plain histogram
        ylabel = 'Probability Density Function (PDF)'
    else:
        ylabel = 'Frequency'
    ax = format_plot(ax,
                     xlabel,
                     ylabel,
                     title,
                     axis_xy=axis_xy,
                     xlim=xlim,
                     ylim=ylim,
                     logx=logx)
    if catdict is not None and var is None:
        # One tick per category, centered on the 1-based bar positions
        ticlocs = [i + 1 for i in range(len(catdict.keys()))]
        ax.set_xticks(ticlocs)
        ax.set_xticklabels(catdict.values())
        ax.set_xlim(0.25, len(catdict) + 0.75)
    elif catdict is not None and var is not None:
        ax.legend()
    _format_tick_labels(ax, rotateticks)
    # format_plot doesn't handle some specialized axis_xy and grid requirements
    # for histogram_plot...
    if icdf:
        # Ensure that we have top spline, in case it was removed above
        ax.spines['top'].set_visible(True)
        _format_grid(ax, grid, below=False)
    else:
        # The grid should be below for a histogram
        _format_grid(ax, grid, below=True)
    # Export figure
    if output_file or ('pdfpages' in out_kws):
        gs.export_image(output_file, **out_kws)
    return ax