Exemple #1
0
    def _check_axes_shape(self,
                          axes,
                          axes_num=None,
                          layout=None,
                          figsize=None):
        """
        Verify that axes were drawn in the expected number, layout and size.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        axes_num : number, optional
            Expected count of visible axes. When given, every visible axes
            must also return a non-empty ``get_children()``. Skipped if None.
        layout : tuple, optional
            Expected layout as (rows, columns). Skipped if None.
        figsize : tuple, optional
            Expected figure size; ``self.default_figsize`` is used if None.
        """
        expected_size = self.default_figsize if figsize is None else figsize
        shown = self._flatten_visible(axes)

        if axes_num is not None:
            self.assertEqual(len(shown), axes_num)
            # every visible axes must have something drawn on it
            for candidate in shown:
                drawn = candidate.get_children()
                assert len(drawn) > 0

        if layout is not None:
            actual_layout = self._get_axes_layout(_flatten(axes))
            self.assertEqual(actual_layout, layout)

        # the figure size is read from the first visible axes' figure
        actual_size = shown[0].figure.get_size_inches()
        tm.assert_numpy_array_equal(actual_size,
                                    np.array(expected_size, dtype=np.float64))
    def _check_axes_shape(self, axes, axes_num=None, layout=None,
                          figsize=None):
        """
        Assert that the drawn axes match the expected count, layout and size.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        axes_num : number, optional
            Expected number of visible axes; each of them must also have a
            non-empty ``get_children()``. Not checked when None.
        layout : tuple, optional
            Expected layout, (number of rows, number of columns). Not
            checked when None.
        figsize : tuple, optional
            Expected figure size. ``self.default_figsize`` when None.
        """
        wanted_size = self.default_figsize if figsize is None else figsize
        visible = self._flatten_visible(axes)

        if axes_num is not None:
            visible_count = len(visible)
            assert visible_count == axes_num
            # a visible axes with no children means nothing was drawn on it
            for shown_ax in visible:
                children = shown_ax.get_children()
                assert len(children) > 0

        if layout is not None:
            found_layout = self._get_axes_layout(_flatten(axes))
            assert found_layout == layout

        # size is taken from the figure that owns the first visible axes
        found_size = visible[0].figure.get_size_inches()
        wanted_array = np.array(wanted_size, dtype=np.float64)
        tm.assert_numpy_array_equal(found_size, wanted_array)
Exemple #3
0
    def _flatten_visible(self, axes):
        """
        Return the flattened axes, keeping only those that are visible.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like

        Returns
        -------
        list
            Entries of ``_flatten(axes)`` whose ``get_visible()`` is truthy.
        """
        return [candidate for candidate in _flatten(axes)
                if candidate.get_visible()]
    def _flatten_visible(self, axes):
        """
        Flatten ``axes`` and drop every entry that is not visible.

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like

        Returns
        -------
        list
            The flattened axes for which ``get_visible()`` is truthy.
        """
        visible = []
        for candidate in _flatten(axes):
            if candidate.get_visible():
                visible.append(candidate)
        return visible
Exemple #5
0
def _setup_subplots(
    subplots,
    nseries,
    sharex=False,
    sharey=False,
    figsize=None,
    ax=None,
    layout=None,
    layout_type="vertical",
):
    """
    Create or reuse the figure and axes to plot on.

    When ``subplots`` is truthy, a grid with ``nseries`` axes is requested
    from ``_subplots``. Otherwise a single axes is used: a new figure with
    one subplot when ``ax`` is None, or the given ``ax`` (whose figure is
    resized when ``figsize`` is given).

    Returns
    -------
    tuple
        ``(fig, axes)`` where ``axes`` has been passed through ``_flatten``.
    """
    from pandas.plotting._tools import _subplots, _flatten

    if subplots:
        grid_options = dict(
            naxes=nseries,
            sharex=sharex,
            sharey=sharey,
            figsize=figsize,
            ax=ax,
            layout=layout,
            layout_type=layout_type,
        )
        fig, axes = _subplots(**grid_options)
    elif ax is None:
        # no axes supplied: start a fresh single-axes figure
        fig = plt.figure(figsize=figsize)
        axes = fig.add_subplot(111)
    else:
        # reuse the caller's axes, resizing its figure only on request
        fig = ax.get_figure()
        if figsize is not None:
            fig.set_size_inches(figsize)
        axes = ax

    return fig, _flatten(axes)
Exemple #6
0
def _joyplot(data,
             grid=False,
             labels=None,
             sublabels=None,
             xlabels=True,
             xlabelsize=None,
             xrot=None,
             ylabelsize=None,
             yrot=None,
             ax=None,
             figsize=None,
             hist=False,
             bins=10,
             fade=False,
             xlim=None,
             ylim='max',
             fill=True,
             linecolor=None,
             overlap=1,
             background=None,
             range_style='all',
             x_range=None,
             tails=0.2,
             title=None,
             legend=False,
             loc="upper right",
             colormap=None,
             color=None,
             **kwargs):
    """
    Internal method.
    Draw a joyplot from an appropriately nested collection of lists
    using matplotlib and pandas.

    Parameters
    ----------
    data : nested collection
        Collection of groups; each group is a collection of subgroups and
        each subgroup a collection of values. One axis is drawn per group.
    grid : boolean or {'x', 'y', 'both'}, default False
        Which axis grid lines to show. True enables both.
    labels : list, default None
        If given, must have one entry per group; entry ``i`` is passed to
        ``_setup_axis`` as the name of axis ``i``.
    sublabels : list, default None
        If given, must have one entry per subgroup of every group; used as
        legend labels. When None the legend is disabled.
    xlabels : boolean, default True
        If exactly True, x tick labels are shown on the backing axis;
        otherwise the x axis of the backing axis is hidden.
    xlabelsize : int, default None
        If specified changes the x-axis label size
    xrot : float, default None
        rotation of x axis labels
    ylabelsize : int, default None
        If specified changes the y-axis label size
    yrot : float, default None
        rotation of y axis labels
    ax : matplotlib axes object, default None
        Forwarded to ``_subplots``.
    figsize : tuple
        The size of the figure to create in inches by default
    hist : boolean, default False
        Draw histograms instead of density plots.
    bins : integer, default 10
        Number of histogram bins to be used
    fade : boolean, default False
        If truthy, ``kwargs['alpha']`` is set per group from ``_get_alpha``.
    xlim : default None
        Not used by this function.
    ylim : 'max', 'own' or value accepted by ``Axes.set_ylim``
        'max' gives all axes the same y limits, 'own' leaves them
        untouched, anything else is passed to ``set_ylim`` on every axis.
    fill : boolean, default True
        Forwarded to ``plot_density``. When exactly True and ``linecolor``
        is None, the line color defaults to black.
    linecolor : color, default None
    overlap : number, default 1
        Controls the (negative) vertical padding between the axes.
    background : color, default None
        Face color of the backing axis.
    range_style : 'all', 'own', 'group', list or ndarray
        How the x range of each density is computed.
    x_range : default None
        If given, the global x range is derived from it instead of from
        the data.
    tails : float, default 0.2
        Forwarded to ``_x_range`` for the 'own' and 'group' range styles.
    title : str, default None
    legend : boolean, default False
    loc : legend location, default "upper right"
    colormap : callable or list of callables, default None
    color : color or list of colors, default None
    kwarg : other plotting keyword arguments
        To be passed to hist/kde plot function

    Returns
    -------
    tuple
        ``(fig, axes)`` where ``axes`` is the list of per-group axes with
        the backing axis appended last.

    Raises
    ------
    NotImplementedError
        If ``range_style`` is not one of the recognized values.
    AssertionError
        If ``labels``, ``sublabels``, ``color`` or ``colormap`` do not
        match the shape of ``data``.
    """

    if fill is True and linecolor is None:
        linecolor = "k"

    if sublabels is None:
        legend = False

    def _get_color(i, num_axes, j, num_subgroups):
        # Resolve the color of subgroup j of group i.
        if isinstance(color, list):
            return color[j] if num_subgroups > 1 else color[i]
        elif color is not None:
            return color
        elif isinstance(colormap, list):
            return colormap[j](i / num_axes)
        elif colormap is None:
            # Fall back to the default color cycle, wrapping around.
            cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
            return cycle_colors[j % len(cycle_colors)]
        else:
            return colormap(i / num_axes)

    ygrid = (grid is True or grid == 'y' or grid == 'both')
    xgrid = (grid is True or grid == 'x' or grid == 'both')

    num_axes = len(data)

    if x_range is None:
        global_x_range = _x_range([v for g in data for sg in g for v in sg])
    else:
        global_x_range = _x_range(x_range, 0.0)
    global_x_min, global_x_max = min(global_x_range), max(global_x_range)

    # Each plot will have its own axis
    fig, axes = _subplots(naxes=num_axes,
                          ax=ax,
                          squeeze=False,
                          sharex=True,
                          sharey=False,
                          figsize=figsize,
                          layout_type='vertical')
    _axes = _flatten(axes)

    # The legend must be drawn in the last axis if we want it at the bottom.
    if loc in (3, 4, 8) or 'lower' in str(loc):
        legend_axis = num_axes - 1
    else:
        legend_axis = 0

    # A couple of simple checks.
    if labels is not None:
        assert len(labels) == num_axes
    if sublabels is not None:
        assert all(len(g) == len(sublabels) for g in data)
    if isinstance(color, list):
        assert all(len(g) <= len(color) for g in data)
    if isinstance(colormap, list):
        assert all(len(g) == len(colormap) for g in data)

    for i, group in enumerate(data):
        a = _axes[i]
        group_zorder = i
        if fade:
            kwargs['alpha'] = _get_alpha(i, num_axes)

        num_subgroups = len(group)

        if hist:
            # matplotlib hist() already handles multiple subgroups in a histogram
            a.hist(group,
                   label=sublabels,
                   bins=bins,
                   color=color,
                   range=[global_x_min, global_x_max],
                   edgecolor=linecolor,
                   zorder=group_zorder,
                   **kwargs)
        else:
            for j, subgroup in enumerate(group):

                # Compute the x range of the current plot
                if range_style == 'all':
                    # All plots have the same range
                    element_x_range = global_x_range
                elif range_style == 'own':
                    # Each plot has its own range
                    element_x_range = _x_range(subgroup, tails)
                elif range_style == 'group':
                    # Each plot has a range that covers the whole group
                    element_x_range = _x_range(group, tails)
                elif isinstance(range_style, (list, np.ndarray)):
                    # All plots have exactly the range passed as argument
                    element_x_range = _x_range(range_style, 0.0)
                else:
                    raise NotImplementedError("Unrecognized range style.")

                sublabel = None if sublabels is None else sublabels[j]

                # Later subgroups are stacked slightly above earlier ones
                # while staying below the next group.
                element_zorder = group_zorder + j / (num_subgroups + 1)
                element_color = _get_color(i, num_axes, j, num_subgroups)

                plot_density(a,
                             element_x_range,
                             subgroup,
                             fill=fill,
                             linecolor=linecolor,
                             label=sublabel,
                             zorder=element_zorder,
                             color=element_color,
                             bins=bins,
                             **kwargs)

        # Setup the current axis: transparency, labels, spines.
        col_name = None if labels is None else labels[i]
        _setup_axis(a,
                    global_x_range,
                    col_name=col_name,
                    grid=ygrid,
                    ylabelsize=ylabelsize,
                    yrot=yrot)

        # When needed, draw the legend
        if legend and i == legend_axis:
            a.legend(loc=loc)
            # Bypass alpha values, in case
            for patch in a.get_legend().get_patches():
                patch.set_facecolor(patch.get_facecolor())
                patch.set_alpha(1.0)
            for line in a.get_legend().get_lines():
                line.set_alpha(1.0)

    # Final adjustments

    # Set the y limit for the density plots.
    # Since the y range in the subplots can vary significantly,
    # different options are available.
    if ylim == 'max':
        # Set all yaxis limit to the same value (max range among all)
        max_ylim = max(a.get_ylim()[1] for a in _axes)
        min_ylim = min(a.get_ylim()[0] for a in _axes)
        for a in _axes:
            a.set_ylim([min_ylim - 0.1 * (max_ylim - min_ylim), max_ylim])

    elif ylim == 'own':
        # Do nothing, each axis keeps its own ylim
        pass

    else:
        # Set all yaxis lim to the argument value ylim
        try:
            for a in _axes:
                a.set_ylim(ylim)
        except Exception:
            # Best effort: an unusable ylim is reported but does not abort
            # the plot. KeyboardInterrupt/SystemExit still propagate.
            print(
                "Warning: the value of ylim must be either 'max', 'own', or a tuple of length 2. The value you provided has no effect."
            )

    # Compute a final axis, used to apply global settings
    last_axis = fig.add_subplot(1, 1, 1)

    # Background color
    if background is not None:
        last_axis.patch.set_facecolor(background)

    for side in ['top', 'bottom', 'left', 'right']:
        last_axis.spines[side].set_visible(_DEBUG)

    # This looks hacky, but all the axes share the x-axis,
    # so they have the same lims and ticks
    last_axis.set_xlim(_axes[0].get_xlim())
    if xlabels is True:
        last_axis.set_xticks(np.array(_axes[0].get_xticks()[1:-1]))
        for t in last_axis.get_xticklabels():
            t.set_visible(True)
            t.set_fontsize(xlabelsize)
            t.set_rotation(xrot)

        # If grid is enabled, do not allow xticks (they are ugly)
        if xgrid:
            last_axis.tick_params(axis='both', which='both', length=0)
    else:
        last_axis.xaxis.set_visible(False)

    last_axis.yaxis.set_visible(False)
    last_axis.grid(xgrid)

    # Last axis on the back
    last_axis.zorder = min(a.zorder for a in _axes) - 1
    _axes = list(_axes) + [last_axis]

    if title is not None:
        # Title the backing axis directly instead of relying on pyplot's
        # implicit "current" figure/axes, which may not be ``fig``.
        last_axis.set_title(title)

    # The magic overlap happens here.
    h_pad = 5 + (-5 * (1 + overlap))
    fig.tight_layout(h_pad=h_pad)

    return fig, _axes
def _joyplot(data,
             grid=False,
             labels=None, sublabels=None,
             xlabels=True, label_strings=None,
             xlabelsize=None, xrot=None,
             ylabelsize=None, yrot=None,
             ax=None, figsize=None,
             hist=False, bins=10,
             fade=False,
             xlim=None, ylim='max',
             fill=True, linecolor=None,
             overlap=1, background=None,
             range_style='all', x_range=None, tails=0.2,
             title=None, x_spacing=None,
             legend=False, loc="upper right",
             colormap=None, color=None, x_title=None,
             **kwargs):
    """
    Internal method.
    Draw a joyplot from an appropriately nested collection of lists
    using matplotlib and pandas.

    Parameters
    ----------
    data : nested collection
        Collection of groups; each group is a collection of subgroups and
        each subgroup a collection of values. One axis is drawn per group.
    grid : boolean or {'x', 'y', 'both'}, default False
        Which axis grid lines to show. True enables both.
    labels : list, default None
        If given, must have one entry per group; entry ``i`` names axis
        ``i`` unless ``label_strings`` is non-empty.
    sublabels : list, default None
        If given, must have one entry per subgroup of every group; used as
        density/histogram labels. When None the legend is disabled.
    xlabels : boolean, default True
        If exactly True, x tick labels are shown on the backing axis;
        otherwise the x axis of the backing axis is hidden.
    label_strings : list, default None
        Optional per-group strings. When non-empty, entry ``i`` replaces
        both the density label and (if ``labels`` is given) the axis name
        of group ``i``. None is treated as an empty list.
    xlabelsize, xrot, ylabelsize, yrot : default None
        Accepted for interface compatibility; not used by this function.
    ax : matplotlib axes object, default None
        Forwarded to ``_subplots``.
    figsize : tuple
        The size of the figure to create in inches by default
    hist : boolean, default False
        Draw histograms instead of density plots.
    bins : integer, default 10
        Number of histogram bins to be used
    fade : boolean, default False
        If truthy, ``kwargs['alpha']`` is set per group from ``_get_alpha``.
    xlim : default None
        Not used by this function.
    ylim : 'max', 'own' or value accepted by ``Axes.set_ylim``
        'max' gives all axes the same y limits, 'own' leaves them
        untouched, anything else is passed to ``set_ylim`` on every axis.
    fill : boolean, default True
        Forwarded to ``plot_density``. When exactly True and ``linecolor``
        is None, the line color defaults to black; when falsy and
        ``linecolor`` is None, each density line uses its own fill color.
    linecolor : color, default None
    overlap : number, default 1
        Controls the (negative) vertical padding between the axes.
    background : color, default None
        Face color of the backing axis.
    range_style : 'all', 'own', 'group', list or ndarray
        How the x range of each density is computed.
    x_range : default None
        If given, the global x range is derived from it instead of from
        the data.
    tails : float, default 0.2
        Forwarded to ``_x_range`` for the 'own' and 'group' range styles.
    title : str, default None
    x_spacing : default None
        Forwarded to ``_setup_axis``.
    legend : boolean, default False
    loc : legend location, default "upper right"
    colormap : callable or list of callables, default None
    color : color or list of colors, default None
        A list is indexed by group.
    x_title : str, default None
        If given, set as the x label of the backing axis.
    kwarg : other plotting keyword arguments
        To be passed to hist/kde plot function

    Returns
    -------
    tuple
        ``(fig, axes)`` where ``axes`` is the list of per-group axes with
        the backing axis appended last.

    Raises
    ------
    NotImplementedError
        If ``range_style`` is not one of the recognized values.
    AssertionError
        If ``labels``, ``sublabels`` or ``colormap`` do not match the shape
        of ``data``.
    """

    # Avoid a shared mutable default; None means "no label strings".
    if label_strings is None:
        label_strings = []
    has_label_strings = len(label_strings) > 0

    if fill is True and linecolor is None:
        linecolor = "k"

    if sublabels is None:
        legend = False

    def _get_color(i, num_axes, j, num_subgroups):
        # Resolve the color of subgroup j of group i.
        if isinstance(color, list):
            return color[i]
        elif color is not None:
            return color
        elif isinstance(colormap, list):
            return colormap[j](i/num_axes)
        elif colormap is None:
            # Fall back to the default color cycle, wrapping around so that
            # more subgroups than cycle colors do not raise IndexError.
            cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
            return cycle_colors[j % len(cycle_colors)]
        else:
            return colormap(i/num_axes)

    ygrid = (grid is True or grid == 'y' or grid == 'both')
    xgrid = (grid is True or grid == 'x' or grid == 'both')

    num_axes = len(data)

    if x_range is None:
        global_x_range = _x_range([v for g in data for sg in g for v in sg])
    else:
        global_x_range = _x_range(x_range, 0.0)
    global_x_min, global_x_max = min(global_x_range), max(global_x_range)

    # Each plot will have its own axis
    fig, axes = _subplots(naxes=num_axes, ax=ax, squeeze=False,
                          sharex=True, sharey=False, figsize=figsize,
                          layout_type='vertical')
    _axes = _flatten(axes)

    # The legend must be drawn in the last axis if we want it at the bottom.
    if loc in (3, 4, 8) or 'lower' in str(loc):
        legend_axis = num_axes - 1
    else:
        legend_axis = 0

    # A couple of simple checks.
    # NOTE: a ``color`` list is indexed by group (see _get_color), so its
    # length is not checked against the number of subgroups.
    if labels is not None:
        assert len(labels) == num_axes
    if sublabels is not None:
        assert all(len(g) == len(sublabels) for g in data)
    if isinstance(colormap, list):
        assert all(len(g) == len(colormap) for g in data)

    for i, group in enumerate(data):
        a = _axes[i]
        group_zorder = i
        if fade:
            kwargs['alpha'] = _get_alpha(i, num_axes)

        num_subgroups = len(group)

        if hist:
            # matplotlib hist() already handles multiple subgroups in a histogram
            a.hist(group, label=sublabels, bins=bins,
                   range=[global_x_min, global_x_max],
                   edgecolor=linecolor, zorder=group_zorder, **kwargs)
        else:
            for j, subgroup in enumerate(group):

                # Compute the x range of the current plot
                if range_style == 'all':
                    # All plots have the same range
                    element_x_range = global_x_range
                elif range_style == 'own':
                    # Each plot has its own range
                    element_x_range = _x_range(subgroup, tails)
                elif range_style == 'group':
                    # Each plot has a range that covers the whole group
                    element_x_range = _x_range(group, tails)
                elif isinstance(range_style, (list, np.ndarray)):
                    # All plots have exactly the range passed as argument
                    element_x_range = _x_range(range_style, 0.0)
                else:
                    raise NotImplementedError("Unrecognized range style.")

                sublabel = None if sublabels is None else sublabels[j]

                # Later subgroups are stacked slightly above earlier ones
                # while staying below the next group.
                element_zorder = group_zorder + j/(num_subgroups+1)
                element_color = _get_color(i, num_axes, j, num_subgroups)

                # Unfilled densities without an explicit line color are
                # drawn in their own color. Decide this per element: writing
                # back to ``linecolor`` would make the first element's color
                # stick for every later line.
                if not fill and linecolor is None:
                    element_linecolor = element_color
                else:
                    element_linecolor = linecolor

                density_label = label_strings[i] if has_label_strings else sublabel

                plot_density(a, element_x_range, subgroup,
                             fill=fill, linecolor=element_linecolor,
                             label=density_label,
                             zorder=element_zorder, color=element_color,
                             bins=bins, **kwargs)

        # Setup the current axis: transparency, labels, spines.
        if labels is None:
            col_name = None
        elif has_label_strings:
            col_name = label_strings[i]
        else:
            col_name = labels[i]
        _setup_axis(a, global_x_range, col_name=col_name, grid=ygrid,
                    x_spacing=x_spacing)

        # When needed, draw the legend
        if legend and i == legend_axis:
            a.legend(loc=loc)
            # Bypass alpha values, in case
            for patch in a.get_legend().get_patches():
                patch.set_alpha(1.0)
            for line in a.get_legend().get_lines():
                line.set_alpha(1.0)

    # Final adjustments

    # Set the y limit for the density plots.
    # Since the y range in the subplots can vary significantly,
    # different options are available.
    if ylim == 'max':
        # Set all yaxis limit to the same value (max range among all)
        max_ylim = max(a.get_ylim()[1] for a in _axes)
        min_ylim = min(a.get_ylim()[0] for a in _axes)
        for a in _axes:
            a.set_ylim([min_ylim - 0.1*(max_ylim-min_ylim), max_ylim])

    elif ylim == 'own':
        # Do nothing, each axis keeps its own ylim
        pass

    else:
        # Set all yaxis lim to the argument value ylim
        try:
            for a in _axes:
                a.set_ylim(ylim)
        except Exception:
            # Best effort: an unusable ylim is reported but does not abort
            # the plot. KeyboardInterrupt/SystemExit still propagate.
            print("Warning: the value of ylim must be either 'max', 'own', or a tuple of length 2. The value you provided has no effect.")

    # Compute a final axis, used to apply global settings
    last_axis = fig.add_subplot(1, 1, 1)

    # Background color
    if background is not None:
        last_axis.patch.set_facecolor(background)

    for side in ['top', 'bottom', 'left', 'right']:
        last_axis.spines[side].set_visible(_DEBUG)

    # This looks hacky, but all the axes share the x-axis,
    # so they have the same lims and ticks
    last_axis.set_xlim(_axes[0].get_xlim())
    if xlabels is True:
        inner_ticks = _axes[0].get_xticks()[1:-1]
        last_axis.set_xticks(inner_ticks)
        last_axis.set_xticklabels(inner_ticks)
        for t in last_axis.get_xticklabels():
            t.set_visible(True)

        # If grid is enabled, do not allow xticks (they are ugly)
        if xgrid:
            last_axis.tick_params(axis='both', which='both', length=0)
    else:
        last_axis.xaxis.set_visible(False)

    last_axis.yaxis.set_visible(False)
    last_axis.grid(xgrid)

    # set the x axis title if you want it
    if x_title is not None:
        last_axis.set_xlabel(x_title)

    # Last axis on the back
    last_axis.zorder = min(a.zorder for a in _axes) - 1
    _axes = list(_axes) + [last_axis]

    if title is not None:
        # Title the backing axis directly instead of relying on pyplot's
        # implicit "current" figure/axes, which may not be ``fig``.
        last_axis.set_title(title)

    # The magic overlap happens here. Use the figure's own method so the
    # layout is applied to ``fig`` even when it is not pyplot's current
    # figure (e.g. when ``ax`` belongs to another figure).
    h_pad = 5 + (- 5*(1 + overlap))
    fig.tight_layout(h_pad=h_pad)

    return fig, _axes