def test_hist_plot_2d():
    """Exercise hist_plot_2d on plain, open-limit and weighted samples.

    The axis limits read back from the axes are expected to lie strictly
    inside (-3, 3) in both the x and y directions.
    """
    _, axes = plt.subplots()
    np.random.seed(0)
    sample_x, sample_y = np.random.randn(2, 10000)

    # Plain call on standard-normal samples.
    hist_plot_2d(axes, sample_x, sample_y)
    xmin, xmax = axes.get_xlim()
    ymin, ymax = axes.get_ylim()
    assert -3 < xmin and xmax < 3 and -3 < ymin and ymax < 3

    # Each limit keyword in turn is given an infinite value.
    open_limits = (
        {'xmin': -np.inf},
        {'xmax': np.inf},
        {'ymin': -np.inf},
        {'ymax': np.inf},
    )
    for open_limit in open_limits:
        hist_plot_2d(axes, sample_x, sample_y, **open_limit)
    # NOTE(review): the limits are not re-read from the axes after the calls
    # above, so this re-checks the values captured after the first call.
    assert -3 < xmin and xmax < 3 and -3 < ymin and ymax < 3

    # Uniform samples re-weighted by a Gaussian, with an explicit bin count.
    sample_x, sample_y = np.random.uniform(-10, 10, (2, 1000000))
    gaussian_weights = np.exp(-(sample_x**2 + sample_y**2) / 2)
    hist_plot_2d(axes, sample_x, sample_y, weights=gaussian_weights, bins=30)
    xmin, xmax = axes.get_xlim()
    ymin, ymax = axes.get_ylim()
    assert -3 < xmin and xmax < 3 and -3 < ymin and ymax < 3
    def plot(self, ax, paramname_x, paramname_y=None, *args, **kwargs):
        """Interface for 2D and 1D plotting routines.

        Produces a single 1D or 2D plot on an axis. If a requested parameter
        is not contained in self, or plot_type is None, an empty line is
        drawn on the axis instead.

        Parameters
        ----------
        ax: matplotlib.axes.Axes
            Axes to plot on

        paramname_x: str
            Choice of parameter to plot on x-coordinate from self.columns.

        paramname_y: str
            Choice of parameter to plot on y-coordinate from self.columns.
            optional, if not provided, or the same as paramname_x, then 1D plot
            produced.

        plot_type: str
            Must be in {'kde', 'scatter', 'hist', 'fastkde'} for 2D plots and
            in {'kde', 'hist', 'fastkde', 'astropyhist'} for 1D plots.
            optional, (Default: 'kde')

        ncompress: int
            Number of samples to use in plotting routines.
            optional, (Default: 1000 for 'kde', 500 for 'scatter'; otherwise
            None is handed to the compress call)

        xmin, xmax, ymin, ymax: float
            Limits forwarded to the plotting routine (ymin/ymax for 2D plots
            only).
            optional, (Default: the values returned by self._limits)

        q: str, float, (float, float)
            Plot the `q` inner posterior quantiles in 1d 'kde' plots. To plot
            the full range, set `q=0` or `q=1`.
            * if str: any of {'1sigma', '2sigma', '3sigma', '4sigma', '5sigma'}
                Plot within mean +/- #sigma of posterior.
            * if float: Plot within the symmetric confidence interval
                `(1-q, q)`  or `(q, 1-q)`.
            * if tuple:  Plot within the (possibly asymmetric) confidence
                interval `q`.
            optional, (Default: '5sigma')

        Any further positional or keyword arguments are forwarded unchanged
        to the selected plotting routine.

        Returns
        -------
        The return value of the selected plotting routine, or None when
        only the empty placeholder line is drawn.

        Raises
        ------
        NotImplementedError
            If plot_type is not one of the supported values for the
            requested dimensionality.

        """
        self._set_automatic_limits()
        plot_type = kwargs.pop('plot_type', 'kde')
        do_1d_plot = paramname_y is None or paramname_x == paramname_y
        kwargs.setdefault('label', self.label)
        ncompress = kwargs.pop('ncompress', None)

        if do_1d_plot:
            if paramname_x not in self or plot_type is None:
                # Nothing to draw: register an empty line on the axis.
                ax.plot([], [])
                return None

            # Limits supplied by the caller take precedence over the stored
            # ones.
            xmin, xmax = self._limits(paramname_x)
            kwargs.setdefault('xmin', xmin)
            kwargs.setdefault('xmax', xmax)

            if plot_type == 'kde':
                if ncompress is None:
                    ncompress = 1000
                return kde_plot_1d(ax, self[paramname_x], *args,
                                   weights=self.weights,
                                   ncompress=ncompress,
                                   **kwargs)
            if plot_type == 'fastkde':
                x = self[paramname_x].compress(ncompress)
                return fastkde_plot_1d(ax, x, *args, **kwargs)
            if plot_type == 'hist':
                return hist_plot_1d(ax, self[paramname_x], *args,
                                    weights=self.weights,
                                    **kwargs)
            if plot_type == 'astropyhist':
                x = self[paramname_x].compress(ncompress)
                return hist_plot_1d(ax, x, *args,
                                    plotter='astropyhist',
                                    **kwargs)
            raise NotImplementedError("plot_type is '%s', but must be"
                                      " one of {'kde', 'fastkde', "
                                      "'hist', 'astropyhist'}." %
                                      plot_type)

        if (paramname_x not in self or paramname_y not in self
                or plot_type is None):
            # Nothing to draw: register an empty line on the axis.
            ax.plot([], [])
            return None

        # Limits supplied by the caller take precedence over the stored ones.
        xmin, xmax = self._limits(paramname_x)
        kwargs.setdefault('xmin', xmin)
        kwargs.setdefault('xmax', xmax)
        ymin, ymax = self._limits(paramname_y)
        kwargs.setdefault('ymin', ymin)
        kwargs.setdefault('ymax', ymax)

        if plot_type == 'kde':
            if ncompress is None:
                ncompress = 1000
            return kde_contour_plot_2d(ax,
                                       self[paramname_x],
                                       self[paramname_y],
                                       *args,
                                       weights=self.weights,
                                       ncompress=ncompress,
                                       **kwargs)
        if plot_type in ('fastkde', 'scatter'):
            if plot_type == 'scatter' and ncompress is None:
                ncompress = 500
            # NOTE(review): x and y are compressed by separate calls; this
            # presumably relies on compress selecting the same rows for
            # both columns -- confirm.
            x = self[paramname_x].compress(ncompress)
            y = self[paramname_y].compress(ncompress)
            if plot_type == 'fastkde':
                return fastkde_contour_plot_2d(ax, x, y, *args, **kwargs)
            return scatter_plot_2d(ax, x, y, *args, **kwargs)
        if plot_type == 'hist':
            return hist_plot_2d(ax,
                                self[paramname_x],
                                self[paramname_y],
                                *args,
                                weights=self.weights,
                                **kwargs)
        # The adjacent literals carry their own separating spaces so the
        # rendered message reads correctly.
        raise NotImplementedError("plot_type is '%s', but must be"
                                  " in {'kde', 'fastkde', "
                                  "'scatter', 'hist'}." %
                                  plot_type)
Example #3
0
    def plot(self, ax, paramname_x, paramname_y=None, *args, **kwargs):
        """Interface for 2D and 1D plotting routines.

        Produces a single 1D or 2D plot on an axis. If a requested parameter
        is not contained in self, or plot_type is None, an empty line is
        drawn on the axis instead.

        Parameters
        ----------
        ax: matplotlib.axes.Axes
            Axes to plot on

        paramname_x: str
            Choice of parameter to plot on x-coordinate from self.columns.

        paramname_y: str
            Choice of parameter to plot on y-coordinate from self.columns.
            optional, if not provided, or the same as paramname_x, then 1D plot
            produced.

        plot_type: str
            Must be in {'kde', 'scatter', 'hist', 'fastkde'} for 2D plots and
            in {'kde', 'hist', 'fastkde', 'astropyhist'} for 1D plots.
            optional, (Default: 'kde')

        ncompress: int
            Number of samples to use in plotting routines.
            optional, (Default: 1000 for 'kde', 500 for 'scatter'; otherwise
            None is handed to the compress call)

        xmin, xmax, ymin, ymax: float
            Limits forwarded to the plotting routine (ymin/ymax for 2D plots
            only). A value given by the caller overrides the stored limit.
            optional, (Default: the values returned by self._limits)

        Any further positional or keyword arguments are forwarded unchanged
        to the selected plotting routine.

        Returns
        -------
        The return value of the selected plotting routine, or None when
        only the empty placeholder line is drawn.

        Raises
        ------
        NotImplementedError
            If plot_type is not one of the supported values for the
            requested dimensionality.

        """
        plot_type = kwargs.pop('plot_type', 'kde')
        do_1d_plot = paramname_y is None or paramname_x == paramname_y
        kwargs.setdefault('label', self.label)
        ncompress = kwargs.pop('ncompress', None)

        if do_1d_plot:
            if paramname_x not in self or plot_type is None:
                # Nothing to draw: register an empty line on the axis.
                ax.plot([], [])
                return None

            # Fold the stored limits into kwargs as defaults. Passing them as
            # explicit keywords next to **kwargs would raise TypeError
            # whenever the caller supplies a limit of their own.
            xmin, xmax = self._limits(paramname_x)
            kwargs.setdefault('xmin', xmin)
            kwargs.setdefault('xmax', xmax)

            if plot_type == 'kde':
                if ncompress is None:
                    ncompress = 1000
                return kde_plot_1d(ax, self[paramname_x], *args,
                                   weights=self.weight,
                                   ncompress=ncompress,
                                   **kwargs)
            if plot_type == 'fastkde':
                x = self[paramname_x].compress(ncompress)
                return fastkde_plot_1d(ax, x, *args, **kwargs)
            if plot_type == 'hist':
                return hist_plot_1d(ax, self[paramname_x], *args,
                                    weights=self.weight,
                                    **kwargs)
            if plot_type == 'astropyhist':
                x = self[paramname_x].compress(ncompress)
                return hist_plot_1d(ax, x, *args,
                                    plotter='astropyhist',
                                    **kwargs)
            raise NotImplementedError("plot_type is '%s', but must be"
                                      " one of {'kde', 'fastkde', "
                                      "'hist', 'astropyhist'}." %
                                      plot_type)

        if (paramname_x not in self or paramname_y not in self
                or plot_type is None):
            # Nothing to draw: register an empty line on the axis.
            ax.plot([], [])
            return None

        # Same defaulting as in the 1D case, for both coordinates.
        xmin, xmax = self._limits(paramname_x)
        ymin, ymax = self._limits(paramname_y)
        kwargs.setdefault('xmin', xmin)
        kwargs.setdefault('xmax', xmax)
        kwargs.setdefault('ymin', ymin)
        kwargs.setdefault('ymax', ymax)

        if plot_type == 'kde':
            if ncompress is None:
                ncompress = 1000
            return kde_contour_plot_2d(ax,
                                       self[paramname_x],
                                       self[paramname_y],
                                       *args,
                                       weights=self.weight,
                                       ncompress=ncompress,
                                       **kwargs)
        if plot_type in ('fastkde', 'scatter'):
            if plot_type == 'scatter' and ncompress is None:
                ncompress = 500
            # NOTE(review): x and y are compressed by separate calls; this
            # presumably relies on compress selecting the same rows for
            # both columns -- confirm.
            x = self[paramname_x].compress(ncompress)
            y = self[paramname_y].compress(ncompress)
            if plot_type == 'fastkde':
                return fastkde_contour_plot_2d(ax, x, y, *args, **kwargs)
            return scatter_plot_2d(ax, x, y, *args, **kwargs)
        if plot_type == 'hist':
            return hist_plot_2d(ax,
                                self[paramname_x],
                                self[paramname_y],
                                *args,
                                weights=self.weight,
                                **kwargs)
        # The adjacent literals carry their own separating spaces so the
        # rendered message reads correctly.
        raise NotImplementedError("plot_type is '%s', but must be"
                                  " in {'kde', 'fastkde', "
                                  "'scatter', 'hist'}." %
                                  plot_type)