Пример #1
0
    def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=None):
        """
        Assert that the drawn axes match the expected count, layout and size.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
            Axes to inspect.
        axes_num : number, optional
            Expected number of visible axes; axes that are not needed are
            expected to have been made invisible. Not checked when None.
        layout : tuple, optional
            Expected layout as (number of rows, number of columns).
            Not checked when None.
        figsize : tuple, optional
            Expected figure size. ``self.default_figsize`` is used when None.
        """
        from pandas.plotting._matplotlib.tools import flatten_axes

        expected_size = self.default_figsize if figsize is None else figsize
        shown = self._flatten_visible(axes)

        if axes_num is not None:
            assert len(shown) == axes_num
            # every visible axes must have had something drawn on it
            for shown_ax in shown:
                assert len(shown_ax.get_children()) > 0

        if layout is not None:
            assert self._get_axes_layout(flatten_axes(axes)) == layout

        # the figure size is always checked, via the first visible axes
        tm.assert_numpy_array_equal(
            shown[0].figure.get_size_inches(),
            np.array(expected_size, dtype=np.float64),
        )
Пример #2
0
def _grouped_plot_by_column(
    plotf,
    data,
    columns=None,
    by=None,
    numeric_only=True,
    grid=False,
    figsize=None,
    ax=None,
    layout=None,
    return_type=None,
    **kwargs,
):
    """
    Draw one subplot per column, each showing that column grouped by ``by``.

    Parameters
    ----------
    plotf : callable
        Called as ``plotf(keys, values, ax, xlabel=..., ylabel=..., **kwargs)``
        once per column; its return value is collected.
    data : object supporting ``groupby`` and ``_get_numeric_data``
        Source data.
    columns : list-like, optional
        Columns to plot. When None, the numeric columns of ``data`` minus
        the ``by`` keys are used.
    by : label or list of labels
        Grouping key(s), passed to ``data.groupby``.
    numeric_only : bool, default True
        Not used by this function.
    grid : bool, default False
        Passed to ``ax.grid`` for every subplot.
    figsize, ax, layout
        Forwarded to ``create_subplots``.
    return_type : object, optional
        When None the axes returned by ``create_subplots`` are returned;
        otherwise a Series of the ``plotf`` results indexed by ``columns``.
    **kwargs
        ``sharex`` / ``sharey`` (default True) are popped and forwarded to
        ``create_subplots``; ``xlabel`` / ``ylabel`` are popped and handed to
        ``plotf``; ``vert`` is read to decide which label defaults to ``by``.
        Everything left over is passed to ``plotf``.

    Returns
    -------
    The axes from ``create_subplots`` or a Series, see ``return_type``.
    """
    grouped = data.groupby(by)
    if columns is None:
        if not isinstance(by, (list, tuple)):
            by = [by]
        columns = data._get_numeric_data().columns.difference(by)
    naxes = len(columns)
    fig, axes = create_subplots(
        naxes=naxes,
        sharex=kwargs.pop("sharex", True),
        sharey=kwargs.pop("sharey", True),
        figsize=figsize,
        ax=ax,
        layout=layout,
    )

    _axes = flatten_axes(axes)

    # GH 45465: move the "by" label based on "vert"
    xlabel, ylabel = kwargs.pop("xlabel", None), kwargs.pop("ylabel", None)
    if kwargs.get("vert", True):
        xlabel = xlabel or by
    else:
        ylabel = ylabel or by

    ax_values = []

    for i, col in enumerate(columns):
        ax = _axes[i]
        gp_col = grouped[col]
        keys, values = zip(*gp_col)
        re_plotf = plotf(keys, values, ax, xlabel=xlabel, ylabel=ylabel, **kwargs)
        ax.set_title(col)
        ax_values.append(re_plotf)
        ax.grid(grid)

    result = pd.Series(ax_values, index=columns)

    # Return axes in multiplot case, maybe revisit later # 985
    if return_type is None:
        result = axes

    # ``by`` is only wrapped in a list when ``columns`` was None, so a scalar
    # key without a length (e.g. an integer column label) can reach this
    # point unchanged; calling len() on it unguarded would raise TypeError.
    if hasattr(by, "__len__") and len(by) == 1:
        byline = by[0]
    else:
        byline = by
    fig.suptitle(f"Boxplot grouped by {byline}")
    maybe_adjust_figure(fig, bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)

    return result
Пример #3
0
def _grouped_plot_by_column(
    plotf,
    data,
    columns=None,
    by=None,
    numeric_only=True,
    grid=False,
    figsize=None,
    ax=None,
    layout=None,
    return_type=None,
    **kwargs,
):
    """
    Draw one subplot per column, each showing that column grouped by ``by``.

    Parameters
    ----------
    plotf : callable
        Called as ``plotf(keys, values, ax, **kwargs)`` once per column; its
        return value is collected.
    data : object supporting ``groupby`` and ``_get_numeric_data``
        Source data.
    columns : list-like, optional
        Columns to plot. When None, the numeric columns of ``data`` minus
        the ``by`` keys are used.
    by : label or list of labels
        Grouping key(s), passed to ``data.groupby``; also used as the x label
        of every subplot.
    numeric_only : bool, default True
        Not used by this function.
    grid : bool, default False
        Passed to ``ax.grid`` for every subplot.
    figsize, ax, layout
        Forwarded to ``create_subplots`` (x and y axes are always shared).
    return_type : object, optional
        When None the axes returned by ``create_subplots`` are returned;
        otherwise a Series of the ``plotf`` results indexed by ``columns``.
    **kwargs
        Passed to ``plotf``.

    Returns
    -------
    The axes from ``create_subplots`` or a Series, see ``return_type``.
    """
    grouped = data.groupby(by)
    if columns is None:
        if not isinstance(by, (list, tuple)):
            by = [by]
        columns = data._get_numeric_data().columns.difference(by)
    naxes = len(columns)
    fig, axes = create_subplots(
        naxes=naxes,
        sharex=True,
        sharey=True,
        figsize=figsize,
        ax=ax,
        layout=layout,
    )

    _axes = flatten_axes(axes)

    ax_values = []

    for i, col in enumerate(columns):
        ax = _axes[i]
        gp_col = grouped[col]
        keys, values = zip(*gp_col)
        re_plotf = plotf(keys, values, ax, **kwargs)
        ax.set_title(col)
        ax.set_xlabel(pprint_thing(by))
        ax_values.append(re_plotf)
        ax.grid(grid)

    result = pd.Series(ax_values, index=columns)

    # Return axes in multiplot case, maybe revisit later # 985
    if return_type is None:
        result = axes

    # ``by`` is only wrapped in a list when ``columns`` was None, so a scalar
    # key without a length (e.g. an integer column label) can reach this
    # point unchanged; calling len() on it unguarded would raise TypeError.
    if hasattr(by, "__len__") and len(by) == 1:
        byline = by[0]
    else:
        byline = by
    fig.suptitle(f"Boxplot grouped by {byline}")
    fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)

    return result
Пример #4
0
    def _flatten_visible(self, axes):
        """
        Return the visible axes as a flat list.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like

        Returns
        -------
        list
            The flattened axes for which ``get_visible()`` is truthy.
        """
        from pandas.plotting._matplotlib.tools import flatten_axes

        return [
            candidate for candidate in flatten_axes(axes) if candidate.get_visible()
        ]
Пример #5
0
def _grouped_plot(
    plotf,
    data,
    column=None,
    by=None,
    numeric_only=True,
    figsize=None,
    sharex=True,
    sharey=True,
    layout=None,
    rot=0,
    ax=None,
    **kwargs,
):
    """
    Draw one subplot per group of ``data.groupby(by)``.

    Parameters
    ----------
    plotf : callable
        Called as ``plotf(group, ax, **kwargs)`` for every group.
    data : object supporting ``groupby``
        Source data.
    column : label or list of labels, optional
        When given, the grouped object is restricted to this selection.
    by : label or list of labels
        Grouping key(s).
    numeric_only : bool, default True
        When truthy, DataFrame groups are reduced with ``_get_numeric_data``
        before being handed to ``plotf``.
    figsize, sharex, sharey, layout, ax
        Forwarded to ``create_subplots``.
    rot : number, default 0
        Not used by this function.
    **kwargs
        Passed to ``plotf``.

    Returns
    -------
    tuple
        ``(fig, axes)`` exactly as returned by ``create_subplots``.

    Raises
    ------
    ValueError
        If ``figsize`` equals the string ``"default"``.
    """
    if figsize == "default":
        # the string "default" is rejected explicitly
        message = (
            "figsize='default' is no longer supported. "
            "Specify figure size by tuple instead"
        )
        raise ValueError(message)

    groups = data.groupby(by)
    if column is not None:
        groups = groups[column]

    fig, axes = create_subplots(
        naxes=len(groups),
        figsize=figsize,
        sharex=sharex,
        sharey=sharey,
        ax=ax,
        layout=layout,
    )
    flat_axes = flatten_axes(axes)

    for position, (key, group) in enumerate(groups):
        target = flat_axes[position]
        if numeric_only and isinstance(group, ABCDataFrame):
            group = group._get_numeric_data()
        plotf(group, target, **kwargs)
        target.set_title(pprint_thing(key))

    return fig, axes
Пример #6
0
def plot(
    plot_func,
    ax=None,
    figsize=None,
    layout=None,
    title=None,
    xlabel=None,
    ylabel=None,
    grid: bool = True,
    legend: bool = False,
    colorbar: bool = False,
):
    """
    Run ``plot_func`` on a single axes and apply common decorations.

    Parameters
    ----------
    plot_func : callable
        Called with the target axes; its return value is used as the
        mappable when ``colorbar`` is requested.
    ax, figsize, layout
        Forwarded to ``pd_mpl_tools.create_subplots``.
    title, xlabel, ylabel
        Passed to ``set_title`` / ``set_xlabel`` / ``set_ylabel``.
    grid : bool, default True
        Passed to ``ax.grid``.
    legend : bool, default False
        Whether to call ``ax.legend()``.
    colorbar : bool, default False
        Whether to add a colorbar for the value returned by ``plot_func``.

    Returns
    -------
    The flattened axes produced by ``pd_mpl_tools.flatten_axes``.
    """
    # TODO: dynamic axes with column = None as default
    # sharex / sharey are intentionally not forwarded for now
    fig, axes = pd_mpl_tools.create_subplots(
        naxes=1,
        ax=ax,
        figsize=figsize,
        layout=layout,
        squeeze=False,
    )
    flat_axes = pd_mpl_tools.flatten_axes(axes)

    # Make the target plot
    target = flat_axes[0]
    drawn = plot_func(target)

    # Other MPL things
    target.set_title(title)
    target.set_xlabel(xlabel)
    target.set_ylabel(ylabel)
    target.grid(grid)
    if legend:
        target.legend()
    if colorbar:
        fig.colorbar(drawn, ax=target)

    return flat_axes
Пример #7
0
def boxplot_frame_groupby(
    grouped,
    subplots=True,
    column=None,
    fontsize=None,
    rot=0,
    grid=True,
    ax=None,
    figsize=None,
    layout=None,
    sharex=False,
    sharey=True,
    **kwds,
):
    """
    Draw box plots for a grouped frame.

    Parameters
    ----------
    grouped : groupby object
        Iterated as ``(key, group)`` pairs; ``len`` and ``.axis`` are used.
    subplots : bool, default True
        When exactly ``True``, one subplot is drawn per group. Otherwise all
        groups are combined into a single frame and plotted together.
    column : label or list of labels, optional
        Column(s) to plot.
    fontsize, rot, grid
        Passed to the underlying ``boxplot`` calls.
    ax, figsize, layout
        Passed to ``create_subplots`` (subplots) or ``boxplot`` (combined).
    sharex, sharey : bool
        Passed to ``create_subplots``; only used when ``subplots`` is True.
    **kwds
        Passed to the underlying ``boxplot`` calls.

    Returns
    -------
    A Series of per-group ``boxplot`` results keyed by group when
    ``subplots`` is True, otherwise the result of the single ``boxplot`` call.
    """
    if subplots is not True:
        keys, frames = zip(*grouped)
        if grouped.axis == 0:
            combined = pd.concat(frames, keys=keys, axis=1)
        elif len(frames) > 1:
            combined = frames[0].join(frames[1:])
        else:
            combined = frames[0]

        # GH 16748, DataFrameGroupby fails when subplots=False and `column`
        # argument is assigned. The combined frame has MultiIndex columns
        # after the groupby, so the keys (grouped values) and the column
        # (original frame column) are coupled to find the subset to plot.
        if column is not None:
            column = com.convert_to_list_like(column)
            column = list(pd.MultiIndex.from_product([keys, column]).values)
        return combined.boxplot(
            column=column,
            fontsize=fontsize,
            rot=rot,
            grid=grid,
            ax=ax,
            figsize=figsize,
            layout=layout,
            **kwds,
        )

    fig, axes = create_subplots(
        naxes=len(grouped),
        squeeze=False,
        ax=ax,
        sharex=sharex,
        sharey=sharey,
        figsize=figsize,
        layout=layout,
    )
    flat_axes = flatten_axes(axes)

    ret = pd.Series(dtype=object)

    for (key, group), group_ax in zip(grouped, flat_axes):
        drawn = group.boxplot(
            ax=group_ax, column=column, fontsize=fontsize, rot=rot, grid=grid, **kwds
        )
        group_ax.set_title(pprint_thing(key))
        ret.loc[key] = drawn
    maybe_adjust_figure(fig, bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
    return ret
Пример #8
0
def _joyplot(data,
             grid=False,
             labels=None, sublabels=None,
             xlabels=True,
             xlabelsize=None, xrot=None,
             ylabelsize=None, yrot=None,
             ax=None, figsize=None,
             hist=False, bins=10,
             fade=False,
             xlim=None, ylim='max',
             fill=True, linecolor=None,
             overlap=1, background=None,
             range_style='all', x_range=None, tails=0.2,
             title=None,
             legend=False, loc="upper right",
             colormap=None, color=None, last_axis=True,
             **kwargs):
    """
    Internal method.
    Draw a joyplot from an appropriately nested collection of lists
    using matplotlib and pandas.

    One axes is created per group in ``data``; every subgroup of a group is
    drawn on that group's axes, either as a histogram or as a density.

    Parameters
    ----------
    data : nested collection
        Iterated as groups -> subgroups -> values.
    grid : boolean or {'x', 'y', 'both'}, default False
        Which grid lines to show. The y grid is passed to every per-group
        axis; the x grid is applied to the final overlay axis only.
    labels : list, default None
        One label per group; if given its length must equal ``len(data)``.
    sublabels : list, default None
        One label per subgroup, used for the legend; if given its length
        must equal the length of every group. When None, the legend is
        disabled.
    xlabels : boolean, default True
        If True, x tick labels are shown on the final overlay axis,
        otherwise its x-axis is hidden. Only used when ``last_axis`` is True.
    xlabelsize : int, default None
        Font size applied to the x tick labels of the overlay axis
    xrot : float, default None
        rotation of x axis labels
    ylabelsize : int, default None
        If specified changes the y-axis label size
    yrot : float, default None
        rotation of y axis labels
    ax : matplotlib axes object, default None
        Forwarded to ``create_subplots``.
    figsize : tuple
        Forwarded to ``create_subplots``.
    hist : boolean, default False
        If truthy each group is drawn with ``Axes.hist``, otherwise every
        subgroup is drawn with ``plot_density``.
    bins : integer, default 10
        Number of histogram bins to be used
    fade : boolean, default False
        If truthy, ``kwargs['alpha']`` is set per group from ``_get_alpha``.
    xlim : default None
        Not used by this function.
    ylim : {'max', 'own'} or value accepted by ``Axes.set_ylim``
        'max' gives all axes the same y limits (extremes over all axes, with
        the lower bound extended by 10% of the span); 'own' leaves each axis
        untouched; anything else is passed to ``set_ylim`` on every axis.
    fill : boolean, default True
        Passed to ``plot_density``.
    linecolor : default None
        Line color for densities / edge color for histograms. Defaults to
        "k" when ``fill`` is True and no value is given.
    overlap : default 1
        Not used by this function (only referenced in commented-out code).
    background : default None
        If given, face color of the final overlay axis.
    range_style : {'all', 'own', 'group'} or list / ndarray, default 'all'
        How the x range of each density is chosen: the global range, the
        subgroup's own range, the range of the whole group, or exactly the
        range given. Not used when ``hist`` is truthy.
    x_range : default None
        If given, the global x range is derived from it instead of from all
        the values in ``data``.
    tails : float, default 0.2
        Passed to ``_x_range`` for the 'own' and 'group' range styles.
    title : str, default None
        If given, set with ``plt.title``.
    legend : boolean, default False
        Whether to draw a legend (forced to False without ``sublabels``).
    loc : default "upper right"
        Legend location; lower locations put the legend on the last group
        axis, all others on the first.
    colormap : callable or list of callables, default None
        Used to pick colors when ``color`` is None; a list must have one
        entry per subgroup.
    color : default None
        A single color, or a list indexed by group position.
    last_axis : boolean, default True
        If True, an extra full-figure axis is added behind the others to
        carry the background, x ticks and x grid.
    kwarg : other plotting keyword arguments
        To be passed to hist/kde plot function

    Returns
    -------
    fig, axes
        The figure and the flat collection of axes; when ``last_axis`` is
        True this is a list ending with the overlay axis.

    Raises
    ------
    NotImplementedError
        If ``range_style`` is not one of the supported values.
    AssertionError
        If ``labels``, ``sublabels`` or a list ``colormap`` has the wrong
        length.
    """

    if fill is True and linecolor is None:
        linecolor = "k"

    if sublabels is None:
        legend = False

    def _get_color(i, num_axes, j, num_subgroups):
        # i: group position, j: subgroup position within the group.
        if isinstance(color, list):
            try:
                return color[i]
            except IndexError:
                # NOTE(review): falls through and returns None when the color
                # list is shorter than the number of groups.
                pass
        elif color is not None:
            return color
        elif isinstance(colormap, list):
            return colormap[j](i/num_axes)
        elif color is None and colormap is None:
            return plt.rcParams['axes.prop_cycle'].by_key()['color'][j]
        else:
            return colormap(i/num_axes)

    ygrid = (grid is True or grid == 'y' or grid == 'both')
    xgrid = (grid is True or grid == 'x' or grid == 'both')

    num_axes = len(data)

    if x_range is None:
        global_x_range = _x_range([v for g in data for sg in g for v in sg])
    else:
        global_x_range = _x_range(x_range, 0.0)
    # NOTE(review): these two names are not read again in this function.
    global_x_min, global_x_max = min(global_x_range), max(global_x_range)

    # Each plot will have its own axis
    fig, axes = create_subplots(naxes=num_axes, ax=ax, squeeze=False,
                          sharex=False, sharey=False, figsize=figsize,
                          layout_type='vertical')
    _axes = flatten_axes(axes)

    # The legend must be drawn in the last axis if we want it at the bottom.
    if loc in (3, 4, 8) or 'lower' in str(loc):
        legend_axis = num_axes - 1
    else:
        legend_axis = 0

    # A couple of simple checks.
    if labels is not None:
        assert len(labels) == num_axes
    if sublabels is not None:
        assert all(len(g) == len(sublabels) for g in data)
    #if isinstance(color, list):
    #    assert all(len(g) == len(color) for g in data)
    if isinstance(colormap, list):
        assert all(len(g) == len(colormap) for g in data)

    for i, group in enumerate(data):
        a = _axes[i]
        # Later groups are drawn on top of earlier ones.
        group_zorder = i
        if fade:
            kwargs['alpha'] = _get_alpha(i, num_axes)

        num_subgroups = len(group)

        if hist:
            # matplotlib hist() already handles multiple subgroups in a histogram
            a.hist(group, label=sublabels, bins=bins, color=color,
                   range=[min(global_x_range), max(global_x_range)],
                   edgecolor=linecolor, zorder=group_zorder, **kwargs)

        else:
            for j, subgroup in enumerate(group):

                # Compute the x_range of the current plot
                # (this rebinds the ``x_range`` argument)
                if range_style == 'all':
                # All plots have the same range
                    x_range = global_x_range
                elif range_style == 'own':
                # Each plot has its own range
                    x_range = _x_range(subgroup, tails)
                elif range_style == 'group':
                # Each plot has a range that covers the whole group
                    x_range = _x_range(group, tails)
                elif isinstance(range_style, (list, np.ndarray)):
                # All plots have exactly the range passed as argument
                    x_range = _x_range(range_style, 0.0)
                else:
                    raise NotImplementedError("Unrecognized range style.")

                if sublabels is None:
                    sublabel = None
                else:
                    sublabel = sublabels[j]

                # Fractional offset keeps subgroups ordered inside their
                # group without reaching the next group's zorder.
                element_zorder = group_zorder + j/(num_subgroups+1)
                element_color = _get_color(i, num_axes, j, num_subgroups)

                plot_density(a, x_range, subgroup,
                             fill=fill, linecolor=linecolor, label=sublabel,
                             zorder=element_zorder, color=element_color,
                             bins=bins, **kwargs)


        # Setup the current axis: transparency, labels, spines.
        col_name = None if labels is None else labels[i]
        _setup_axis(a, global_x_range, col_name=col_name, grid=ygrid,
                ylabelsize=ylabelsize, yrot=yrot)

        # When needed, draw the legend
        if legend and i == legend_axis:
            a.legend(loc=loc)
            # Bypass alpha values, in case
            for p in a.get_legend().get_patches():
                p.set_facecolor(p.get_facecolor())
                p.set_alpha(1.0)
            for l in a.get_legend().get_lines():
                l.set_alpha(1.0)


    # Final adjustments

    # Set the y limit for the density plots.
    # Since the y range in the subplots can vary significantly,
    # different options are available.
    if ylim == 'max':
        # Set all yaxis limit to the same value (max range among all)
        max_ylim = max(a.get_ylim()[1] for a in _axes)
        min_ylim = min(a.get_ylim()[0] for a in _axes)
        for a in _axes:
            a.set_ylim([min_ylim - 0.1*(max_ylim-min_ylim), max_ylim])

    elif ylim == 'own':
        # Do nothing, each axis keeps its own ylim
        pass

    else:
        # Set all yaxis lim to the argument value ylim
        try:
            for a in _axes:
                a.set_ylim(ylim)
        # NOTE(review): the bare except swallows every exception and only
        # prints a warning, so an invalid ylim is silently ignored.
        except:
            print("Warning: the value of ylim must be either 'max', 'own', or a tuple of length 2. The value you provided has no effect.")
    if last_axis is True:
        # Compute a final axis, used to apply global settings
        # (``last_axis`` is rebound from the boolean flag to the new Axes)
        last_axis = fig.add_subplot(1, 1, 1)

        # Background color
        if background is not None:
            last_axis.patch.set_facecolor(background)

        # Spines are only shown when the module-level _DEBUG flag is truthy.
        for side in ['top', 'bottom', 'left', 'right']:
            last_axis.spines[side].set_visible(_DEBUG)

        # This looks hacky, but all the axes share the x-axis,
        # so they have the same lims and ticks
        last_axis.set_xlim(_axes[0].get_xlim())
        if xlabels is True:
            # The first and last ticks of the first axis are dropped.
            last_axis.set_xticks(np.array(_axes[0].get_xticks()[1:-1]))
            for t in last_axis.get_xticklabels():
                t.set_visible(True)
                t.set_fontsize(xlabelsize)
                t.set_rotation(xrot)

            # If grid is enabled, do not allow xticks (they are ugly)
            if xgrid:
                last_axis.tick_params(axis='both', which='both',length=0)
        else:
            last_axis.xaxis.set_visible(False)

        last_axis.yaxis.set_visible(False)
        last_axis.grid(xgrid)


        # Last axis on the back
        last_axis.zorder = min(a.zorder for a in _axes) - 1
        _axes = list(_axes) + [last_axis]

    if title is not None:
        plt.title(title)


    # The magic overlap happens here.
    #h_pad = 5 + (- 5*(1 + overlap))
    #fig.tight_layout(h_pad=h_pad)

    return fig, _axes
Пример #9
0
def hist_frame(
    data,
    column=None,
    by=None,
    grid=True,
    xlabelsize=None,
    xrot=None,
    ylabelsize=None,
    yrot=None,
    ax=None,
    sharex=False,
    sharey=False,
    figsize=None,
    layout=None,
    bins=10,
    legend: bool = False,
    **kwds,
):
    """
    Draw one histogram per numeric or datetime column of ``data``.

    Parameters
    ----------
    data : DataFrame
        Source data.
    column : label or list-like of labels, optional
        Restrict the plot to these columns.
    by : object, optional
        When given, the whole call is delegated to ``_grouped_hist``.
    grid : bool, default True
        Passed to ``ax.grid`` for every subplot.
    xlabelsize, xrot, ylabelsize, yrot
        Tick label properties, applied through ``set_ticks_props``.
    ax, sharex, sharey, figsize, layout
        Forwarded to ``create_subplots``.
    bins : int or sequence, default 10
        Passed to ``Axes.hist``.
    legend : bool, default False
        Whether to label each histogram with its column name and draw a
        legend.
    **kwds
        Passed to ``Axes.hist``.

    Returns
    -------
    The axes returned by ``create_subplots`` (or by ``_grouped_hist``).

    Raises
    ------
    ValueError
        If both ``legend`` and a ``label`` keyword are given, or if no
        plottable column remains.
    """
    if legend and "label" in kwds:
        raise ValueError("Cannot use both legend and label")

    if by is not None:
        return _grouped_hist(
            data,
            column=column,
            by=by,
            ax=ax,
            grid=grid,
            figsize=figsize,
            sharex=sharex,
            sharey=sharey,
            layout=layout,
            bins=bins,
            xlabelsize=xlabelsize,
            xrot=xrot,
            ylabelsize=ylabelsize,
            yrot=yrot,
            legend=legend,
            **kwds,
        )

    if column is not None:
        if not isinstance(column, (list, np.ndarray, ABCIndexClass)):
            column = [column]
        data = data[column]
    # GH32590
    data = data.select_dtypes(
        include=(np.number, "datetime64", "datetimetz"), exclude="timedelta"
    )

    naxes = len(data.columns)
    if naxes == 0:
        raise ValueError(
            "hist method requires numerical or datetime columns, nothing to plot."
        )

    fig, axes = create_subplots(
        naxes=naxes,
        ax=ax,
        squeeze=False,
        sharex=sharex,
        sharey=sharey,
        figsize=figsize,
        layout=layout,
    )
    flat_axes = flatten_axes(axes)

    # decided once, before the loop starts writing "label" into kwds
    label_from_column = "label" not in kwds

    for position, name in enumerate(data.columns):
        target = flat_axes[position]
        if legend and label_from_column:
            kwds["label"] = name
        target.hist(data[name].dropna().values, bins=bins, **kwds)
        target.set_title(name)
        target.grid(grid)
        if legend:
            target.legend()

    set_ticks_props(
        axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
    )
    fig.subplots_adjust(wspace=0.3, hspace=0.3)

    return axes
Пример #10
0
def hist_frame(
    data,
    column=None,
    by=None,
    grid=True,
    xlabelsize=None,
    xrot=None,
    ylabelsize=None,
    yrot=None,
    ax=None,
    sharex=False,
    sharey=False,
    figsize=None,
    layout=None,
    bins=10,
    **kwds,
):
    """
    Draw one histogram per column from precomputed bins and weights.

    The histogram data comes from ``data._hist(num_bins=bins)``, which
    returns two frames (bins and weights) with one column per plotted field.

    Parameters
    ----------
    data : object providing ``_hist(num_bins=...)``
        Source data.
    column : label or list-like of labels, optional
        Restrict the plot to these columns of the histogram frames.
    by : object, optional
        Grouping is not implemented; any value other than None raises.
    grid : bool, default True
        Passed to ``ax.grid`` for every subplot.
    xlabelsize, xrot, ylabelsize, yrot
        Tick label properties, applied through ``set_ticks_props``.
    ax, sharex, sharey, figsize, layout
        Forwarded to ``create_subplots``.
    bins : int, default 10
        Number of bins requested from ``data._hist``.
    **kwds
        Passed to ``Axes.hist``.

    Returns
    -------
    The axes returned by ``create_subplots``.

    Raises
    ------
    NotImplementedError
        If ``by`` is not None.
    ValueError
        If there is no column to plot.
    """
    # Per-column histogram frames (bins and weights) computed by the data
    # object itself.
    ed_df_bins, ed_df_weights = data._hist(num_bins=bins)

    converter._WARN = False  # no warning for pandas plots
    if by is not None:
        raise NotImplementedError("TODO")

    if column is not None:
        if not isinstance(column, (list, np.ndarray, ABCIndexClass)):
            column = [column]
        ed_df_bins = ed_df_bins[column]
        ed_df_weights = ed_df_weights[column]
    naxes = len(ed_df_bins.columns)

    if naxes == 0:
        raise ValueError("hist method requires numerical columns, "
                         "nothing to plot.")

    fig, axes = create_subplots(
        naxes=naxes,
        ax=ax,
        squeeze=False,
        sharex=sharex,
        sharey=sharey,
        figsize=figsize,
        layout=layout,
    )
    _axes = flatten_axes(axes)

    # Iterate over the columns that actually have histogram data: the axes
    # were sized from ``ed_df_bins.columns``, so looping over ``data.columns``
    # would raise KeyError/IndexError whenever ``column`` selects a subset or
    # ``data`` has columns that are absent from the histogram frames.
    for i, col in enumerate(try_sort(ed_df_bins.columns)):
        ax = _axes[i]
        # The bins column is passed as the bin edges; it is sliced by one so
        # that there is one sample per weight.
        ax.hist(
            ed_df_bins[col][:-1],
            bins=ed_df_bins[col],
            weights=ed_df_weights[col],
            **kwds,
        )
        ax.set_title(col)
        ax.grid(grid)

    set_ticks_props(axes,
                    xlabelsize=xlabelsize,
                    xrot=xrot,
                    ylabelsize=ylabelsize,
                    yrot=yrot)
    fig.subplots_adjust(wspace=0.3, hspace=0.3)

    return axes
Пример #11
0
def boxplot_frame_groupby(
    grouped,
    subplots=True,
    column=None,
    fontsize=None,
    rot=0,
    grid=True,
    ax=None,
    figsize=None,
    layout=None,
    sharex=False,
    sharey=True,
    **kwds,
):
    """
    Draw box plots for a grouped frame.

    Parameters
    ----------
    grouped : groupby object
        Iterated as ``(key, group)`` pairs; ``len`` and ``.axis`` are used.
    subplots : bool, default True
        When exactly ``True``, one subplot is drawn per group. Otherwise all
        groups are combined into a single frame and plotted together.
    column : label or list of labels, optional
        Column(s) to plot.
    fontsize, rot, grid
        Passed to the underlying ``boxplot`` calls.
    ax, figsize, layout
        Passed to ``create_subplots`` (subplots) or ``boxplot`` (combined).
    sharex, sharey : bool
        Passed to ``create_subplots``; only used when ``subplots`` is True.
    **kwds
        Passed to the underlying ``boxplot`` calls.

    Returns
    -------
    A Series of per-group ``boxplot`` results keyed by group when
    ``subplots`` is True, otherwise the result of the single ``boxplot`` call.
    """
    if subplots is True:
        naxes = len(grouped)
        fig, axes = create_subplots(
            naxes=naxes,
            squeeze=False,
            ax=ax,
            sharex=sharex,
            sharey=sharey,
            figsize=figsize,
            layout=layout,
        )
        axes = flatten_axes(axes)

        ret = pd.Series(dtype=object)

        for (key, group), ax in zip(grouped, axes):
            d = group.boxplot(ax=ax,
                              column=column,
                              fontsize=fontsize,
                              rot=rot,
                              grid=grid,
                              **kwds)
            ax.set_title(pprint_thing(key))
            ret.loc[key] = d
        fig.subplots_adjust(bottom=0.15,
                            top=0.9,
                            left=0.1,
                            right=0.9,
                            wspace=0.2)
    else:
        keys, frames = zip(*grouped)
        if grouped.axis == 0:
            df = pd.concat(frames, keys=keys, axis=1)
            # GH 16748: concatenating with ``keys`` gives ``df`` MultiIndex
            # columns of (group key, original column), so a plain ``column``
            # selection no longer matches anything. Couple every group key
            # with every requested column to address the right subset.
            if column is not None:
                if pd.api.types.is_list_like(column):
                    column = list(column)
                else:
                    column = [column]
                column = list(pd.MultiIndex.from_product([keys, column]).values)
        else:
            # Joined frames keep their original column labels, so ``column``
            # is used as given.
            if len(frames) > 1:
                df = frames[0].join(frames[1::])
            else:
                df = frames[0]
        ret = df.boxplot(
            column=column,
            fontsize=fontsize,
            rot=rot,
            grid=grid,
            ax=ax,
            figsize=figsize,
            layout=layout,
            **kwds,
        )
    return ret