예제 #1
0
파일: plotly.py 프로젝트: ynuosoft/spark
def plot_histogram(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
    """Return a stacked plotly bar chart showing the histogram of ``data``.

    Only the ``bins`` keyword (default ``10``) is read from ``kwargs``; it is
    handed to ``HistogramPlotBase.prepare_hist_data``. Every series produced
    by ``HistogramPlotBase.compute_hist`` becomes one ``go.Bar`` trace placed
    at the bucket midpoints. Each trace's hover label has the form
    ``[lo, hi)``, or ``[lo, hi]`` for the final bucket.

    Raises:
        AssertionError: if ``prepare_hist_data`` yields two or fewer bin
            edges (note: skipped when Python runs with ``-O``).
    """
    import plotly.graph_objs as go

    kdf, edges = HistogramPlotBase.prepare_hist_data(data, kwargs.get("bins", 10))
    assert len(edges) > 2, "the number of buckets must be higher than 2."
    hist_series = HistogramPlotBase.compute_hist(kdf, edges)

    # Round every edge to 9 decimals purely so the hover text is readable.
    shown = [float("%.9f" % edge) for edge in edges]
    labels = ["[%s, %s)" % pair for pair in zip(shown, shown[1:])]
    labels[-1] = labels[-1][:-1] + "]"  # the final bucket is closed on the right

    # Bars sit at the midpoint of their bucket (edges are array-like here).
    midpoints = 0.5 * (edges[:-1] + edges[1:])

    traces = []
    for hist in list(hist_series):
        label = name_like_string(hist.name)
        traces.append(
            go.Bar(
                x=midpoints,
                y=hist,
                name=label,
                text=labels,
                hovertemplate="variable=" + label + "<br>value=%{text}<br>count=%{y}",
            )
        )

    fig = go.Figure(data=traces, layout=go.Layout(barmode="stack"))
    for axis, title in (("xaxis", "value"), ("yaxis", "count")):
        fig["layout"][axis]["title"] = title
    return fig
예제 #2
0
def plot_histogram(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
    """Build a stacked plotly bar chart rendering the histogram of ``data``.

    Parameters
    ----------
    data : ps.DataFrame or ps.Series
        Values to bin. One ``go.Bar`` trace is emitted for every series
        returned by ``HistogramPlotBase.compute_hist``.
    **kwargs
        ``bins`` (default ``10``) is forwarded to
        ``HistogramPlotBase.prepare_hist_data``.
        ``y`` selects column(s) of a DataFrame before binning (ignored when
        ``data`` is a Series). Any falsy-but-valid label such as ``0`` is
        honoured; only ``None`` / absence means "use all columns".
        Any other key that is a parameter of ``plotly.graph_objs.Layout`` is
        forwarded to the figure's layout.

    Returns
    -------
    plotly.graph_objs.Figure
        Figure with ``barmode="stack"`` and axis titles "value" / "count".

    Raises
    ------
    AssertionError
        If two or fewer bin edges are produced (skipped under ``python -O``).
    """
    import plotly.graph_objs as go
    import pyspark.pandas as ps

    bins = kwargs.get("bins", 10)
    y = kwargs.get("y")
    # Compare against None instead of testing truthiness: an integer column
    # label of 0 is a legitimate selection and must not be silently ignored.
    if y is not None and isinstance(data, ps.DataFrame):
        # Note that the results here are matched with matplotlib. x and y
        # handling is different from pandas' plotly output.
        data = data[y]
    psdf, bins = HistogramPlotBase.prepare_hist_data(data, bins)
    assert len(bins) > 2, "the number of buckets must be higher than 2."
    output_series = HistogramPlotBase.compute_hist(psdf, bins)

    # Truncate edges to 9 decimal places so the hover text stays readable.
    edges = [float("%.9f" % b) for b in bins]
    text_bins = ["[%s, %s)" % (lo, hi) for lo, hi in zip(edges[:-1], edges[1:])]
    # Buckets are half-open except the last, which includes its upper edge.
    text_bins[-1] = text_bins[-1][:-1] + "]"

    # Place each bar at the midpoint of its bucket (bins is array-like here).
    centers = 0.5 * (bins[:-1] + bins[1:])

    bars = []
    for series in output_series:
        series_name = name_like_string(series.name)
        bars.append(
            go.Bar(
                x=centers,
                y=series,
                name=series_name,
                text=text_bins,
                hovertemplate=(
                    "variable=" + series_name + "<br>value=%{text}<br>count=%{y}"
                ),
            )
        )

    # Forward only the kwargs that go.Layout actually accepts.
    layout_keys = inspect.signature(go.Layout).parameters.keys()
    layout_kwargs = {k: v for k, v in kwargs.items() if k in layout_keys}

    fig = go.Figure(data=bars, layout=go.Layout(**layout_kwargs))
    # Assigned after construction, so "stack" wins even if a barmode was
    # forwarded through kwargs.
    fig["layout"]["barmode"] = "stack"
    fig["layout"]["xaxis"]["title"] = "value"
    fig["layout"]["yaxis"]["title"] = "count"
    return fig
예제 #3
0
 def _compute_plot_data(self):
     """Prepare ``self.data`` and ``self.bins`` for histogram plotting.

     Both attributes are replaced by the pair returned from
     ``HistogramPlotBase.prepare_hist_data(self.data, self.bins)``.
     """
     prepared_data, prepared_bins = HistogramPlotBase.prepare_hist_data(
         self.data, self.bins
     )
     self.data, self.bins = prepared_data, prepared_bins