Beispiel #1
0
def test_scale_transformed_breaks():
    """Explicit geom_bin2d breaks must follow a log10-transformed x scale.

    Without a scale transform the right edges of the x bins are the raw
    break values (50 and 500); after adding ``scale_x_log10`` they must be
    the log10 of those same values.
    """
    x_breaks = [5, 50, 500]
    y_breaks = [0.5, 1.5, 2.5]
    frame = pd.DataFrame({'x': [1, 10, 100, 1000], 'y': range(4)})
    base = ggplot(frame, aes('x', 'y')) + geom_bin2d(breaks=(x_breaks, y_breaks))

    linear = layer_data(base)
    logged = layer_data(base + scale_x_log10())

    expected_xmax = [50, 500]
    np.testing.assert_allclose(linear.xmax, expected_xmax)
    np.testing.assert_allclose(logged.xmax, np.log10(expected_xmax))
Beispiel #2
0
def test_scale_transformed_breaks():
    """Bin breaks given in data units must be transformed along with the x scale."""
    xs = [1, 10, 100, 1000]
    df = pd.DataFrame(dict(x=xs, y=range(len(xs))))
    binned = ggplot(df, aes('x', 'y')) + geom_bin2d(
        breaks=([5, 50, 500], [0.5, 1.5, 2.5]))

    # Same plot computed twice: once on a linear x scale, once on log10.
    untransformed, log_scaled = (
        layer_data(binned),
        layer_data(binned + scale_x_log10()),
    )

    right_edges = np.array([50, 500])
    np.testing.assert_allclose(untransformed.xmax, right_edges)
    np.testing.assert_allclose(log_scaled.xmax, np.log10(right_edges))
Beispiel #3
0
     p9.scale_y_timedelta(labels=timedelta_format("d")) + p9.annotate(
         "text",
         x=10,
         y=timedelta(days=1450),
         label=f"Y={results_2.slope:.2f}*X+{results_2.intercept:.2f}",
     ) + p9.labs(
         x="Euclidean Distance of Preprints First and Final Versions",
         y="Time Elapsed Until Preprint is Published",
     ) + p9.theme_seaborn(
         context="paper",
         style="ticks", font="Arial", font_scale=1.3))
print(g)

g = (p9.ggplot(published_date_distances,
               p9.aes(x="doc_distances", y="time_to_published")) +
     p9.geom_bin2d(bins=100) + p9.scale_fill_distiller(trans="log",
                                                       direction=-1,
                                                       type="seq",
                                                       palette="YlGnBu",
                                                       name="log(count)") +
     p9.geom_line(stat="smooth",
                  method="lm",
                  linetype="dashed",
                  se=False,
                  alpha=1,
                  size=0.7) +
     p9.scale_y_timedelta(labels=timedelta_format("d")) + p9.annotate(
         "text",
         x=7.5,
         y=timedelta(days=1490),
         label=f"Y={results_2.slope:.2f}*X+{results_2.intercept:.2f}",
Beispiel #4
0
def test_drop_false():
    """Render a binwidth-2 geom_bin2d with ``drop=False`` and compare to 'drop_false'."""
    # NOTE(review): ``df`` and ``_theme`` are module-level names defined elsewhere;
    # ``plot == 'name'`` presumably invokes the suite's baseline-image comparison
    # hook — confirm against the test harness.
    bins = geom_bin2d(binwidth=2, drop=False)
    plot = ggplot(df, aes('x', 'y')) + bins
    assert plot + _theme == 'drop_false'
Beispiel #5
0
def test_drop_true():
    """Render a binwidth-2 geom_bin2d with ``drop=True`` and compare to 'drop_true'."""
    # NOTE(review): relies on module-level ``df`` / ``_theme`` and on ``==`` against
    # a string being the harness's image-comparison check — confirm.
    plot = ggplot(df, aes('x', 'y'))
    plot = plot + geom_bin2d(drop=True, binwidth=2)
    themed = plot + _theme
    assert themed == 'drop_true'
Beispiel #6
0
def test_drop_false():
    """Check the ``drop=False`` geom_bin2d rendering against the 'drop_false' baseline."""
    # NOTE(review): if this lives in the same module as an earlier test_drop_false,
    # this definition shadows it and only one of them is collected.
    binwidth, drop = 2, False
    themed_plot = ggplot(df, aes('x', 'y')) + geom_bin2d(binwidth=binwidth, drop=drop)
    assert themed_plot + _theme == 'drop_false'
Beispiel #7
0
def test_drop_true():
    """Check the ``drop=True`` geom_bin2d rendering against the 'drop_true' baseline."""
    # NOTE(review): if this lives in the same module as an earlier test_drop_true,
    # this definition shadows it and only one of them is collected.
    binwidth, drop = 2, True
    themed_plot = ggplot(df, aes('x', 'y')) + geom_bin2d(binwidth=binwidth, drop=drop)
    assert themed_plot + _theme == 'drop_true'
def histogram2d(
    data: pd.DataFrame,
    column1: str,
    column2: str,
    fig: plt.Figure = None,
    ax: plt.Axes = None,
    fig_width: int = 6,
    fig_height: int = 6,
    trend_line: str = "auto",
    lower_quantile1: float = 0,
    upper_quantile1: float = 1,
    lower_quantile2: float = 0,
    upper_quantile2: float = 1,
    transform1: str = "identity",
    transform2: str = "identity",
    equalize_axes: bool = False,
    reference_line: bool = False,
    plot_density: bool = False,
) -> Tuple[plt.Figure, plt.Axes, p9.ggplot]:
    """
    Creates an EDA plot for two continuous variables as a 2d histogram (binned counts).

    Args:
        data: pandas DataFrame containing data to be plotted
        column1: name of column to plot on the x axis
        column2: name of column to plot on the y axis
        fig: matplotlib Figure generated from blank ggplot to plot onto. If specified, must also specify ax
        ax: matplotlib axes generated from blank ggplot to plot onto. If specified, must also specify fig
        fig_width: figure width in inches
        fig_height: figure height in inches (ignored when ``equalize_axes`` is True, where the
            figure is made square using ``fig_width``)
        trend_line: Smoothing method for a trend line drawn over the data, passed as ``method`` to
            `geom_smooth <https://plotnine.readthedocs.io/en/stable/generated/plotnine.geoms.geom_smooth.html>`_.
            Use ``"none"`` to draw no trend line. Defaults to ``"auto"``, which does draw one.
        lower_quantile1: Lower quantile of column1 data to remove before plotting for ignoring outliers
        upper_quantile1: Upper quantile of column1 data to remove before plotting for ignoring outliers
        lower_quantile2: Lower quantile of column2 data to remove before plotting for ignoring outliers
        upper_quantile2: Upper quantile of column2 data to remove before plotting for ignoring outliers
        transform1: Transformation to apply to the column1 data for plotting:

         - **'identity'**: no transformation
         - **'log'**: apply a logarithmic transformation with small constant added in case of zero values
         - **'log_exclude0'**: apply a logarithmic transformation with zero values removed
         - **'sqrt'**: apply a square root transformation
        transform2: Transformation to apply to the column2 data for plotting. Same options as for column1.
        equalize_axes: Square the aspect ratio and match the axis limits
        reference_line: Add a y = x reference line
        plot_density: Overlay a 2d density on the given plot

    Returns:
        Tuple containing matplotlib figure and axes along with the plotnine ggplot object

    Raises:
        ValueError: if exactly one of ``fig`` and ``ax`` is given.

    Examples:
        .. plot::

            import pandas as pd
            import intedact
            data = pd.read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2018/2018-09-11/cats_vs_dogs.csv")
            intedact.histogram2d(data, 'n_dog_households', 'n_cat_households', equalize_axes=True, reference_line=True);
    """
    # fig and ax are only usable as a pair: supplying just one would hand a None
    # to gg._draw_using_figure below and fail with an unhelpful error.
    if (fig is None) != (ax is None):
        raise ValueError("fig and ax must be specified together (or both left as None)")

    # Drop outliers per column, then apply the requested value transformations.
    data = trim_quantiles(data,
                          column1,
                          lower_quantile=lower_quantile1,
                          upper_quantile=upper_quantile1)
    data = trim_quantiles(data,
                          column2,
                          lower_quantile=lower_quantile2,
                          upper_quantile=upper_quantile2)
    data = preprocess_transformations(data, column1, transform=transform1)
    data = preprocess_transformations(data, column2, transform=transform2)

    # base plot: 2d histogram of binned counts
    gg = p9.ggplot(data, p9.aes(x=column1, y=column2)) + p9.geom_bin2d()

    # overlay density
    if plot_density:
        gg += p9.geom_density_2d()

    # add y = x reference line (geom_abline with its default slope/intercept)
    if reference_line:
        gg += p9.geom_abline(color="black")

    # add trend line; "none" is the only value that suppresses it
    if trend_line != "none":
        gg += p9.geom_smooth(method=trend_line, color="red")

    # blank out the fill legend title
    gg += p9.labs(fill="")

    # handle axes transforms; transform_axis also returns the axis label to apply
    # after drawing (presumably reflecting the transform — see transform_axis)
    gg, xlabel = transform_axis(gg, column1, transform1, xaxis=True)
    gg, ylabel = transform_axis(gg, column2, transform2, xaxis=False)

    if fig is None:
        # No target supplied: draw a new figure and take its (first) axes.
        gg.draw()
        fig = plt.gcf()
        ax = fig.axes[0]
    else:
        # Draw onto the caller-supplied figure/axes pair.
        _ = gg._draw_using_figure(fig, [ax])

    if equalize_axes:
        fig, ax, gg = match_axes(fig, ax, gg)
        # square figure: fig_height is intentionally ignored here
        fig.set_size_inches(fig_width, fig_width)
    else:
        fig.set_size_inches(fig_width, fig_height)

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    return fig, ax, gg
     p9.scale_y_timedelta(labels=timedelta_format("d")) + p9.annotate(
         "text",
         x=10,
         y=timedelta(days=1450),
         label=f"Y={results_2.slope:.2f}*X+{results_2.intercept:.2f}",
     ) + p9.labs(
         x="Eucledian Distance of Preprints First and Final Versions",
         y="Time Elapsed Until Preprint is Published",
     ) + p9.theme_seaborn(
         context="paper",
         style="ticks", font="Arial", font_scale=1.3))
print(g)

g = (p9.ggplot(published_date_distances,
               p9.aes(x="doc_distances", y="time_to_published")) +
     p9.geom_bin2d(bins=100) + p9.scale_fill_distiller(
         trans="log",
         direction=-1,
         type="seq",
         palette="YlGnBu",
         name="log(count)",
         labels=log_format(base=10),
     ) + p9.geom_line(stat="smooth",
                      method="lm",
                      linetype="dashed",
                      se=False,
                      alpha=1,
                      size=0.7) +
     p9.scale_y_timedelta(labels=timedelta_format("d")) + p9.annotate(
         "text",
         x=10,