def plot_kde(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
    """Draw a kernel density estimate plot with ``plotly.express.line``.

    The data is first passed through ``KdePlotBase.prepare_kde_data``.
    Then one density curve is computed per column label, all evaluated at a
    single shared set of points. The curves are stacked into one long-form
    pandas DataFrame with the columns ``"Density"``, ``"names"`` and
    ``"index"``, and that frame is plotted.

    Parameters
    ----------
    data : ps.DataFrame or ps.Series
        The pandas-on-Spark object to plot.
    **kwargs
        ``ind`` and ``bw_method`` are removed here and forwarded to the KDE
        helpers. Every other keyword is passed to ``plotly.express.line``.
        When ``data`` is a DataFrame and no ``color`` is given, ``color``
        defaults to ``"names"`` so each column becomes its own trace.

    Returns
    -------
    The figure built by ``plotly.express.line``, with its x-axis title cleared.
    """
    from plotly import express
    import pyspark.pandas as ps

    # Separate the curves by column name unless the caller picked a colour key.
    if isinstance(data, ps.DataFrame) and "color" not in kwargs:
        kwargs["color"] = "names"

    prepared = KdePlotBase.prepare_kde_data(data)
    internal = prepared._internal
    spark_frame = internal.spark_frame

    # A single set of evaluation points is derived from all data columns and
    # reused for every curve, so the curves share x values.
    eval_points = KdePlotBase.get_ind(
        spark_frame.select(*internal.data_spark_columns), kwargs.pop("ind", None)
    )
    bw_method = kwargs.pop("bw_method", None)

    curves = [
        pd.DataFrame(
            {
                "Density": KdePlotBase.compute_kde(
                    spark_frame.select(internal.spark_column_for(label)),
                    ind=eval_points,
                    bw_method=bw_method,
                ),
                "names": name_like_string(label),
                "index": eval_points,
            }
        )
        for label in internal.column_labels
    ]

    fig = express.line(pd.concat(curves), x="index", y="Density", **kwargs)
    # Clear the x-axis title; presumably it would otherwise read "index".
    fig["layout"]["xaxis"]["title"] = None
    return fig
def _plot(cls, ax, y, style=None, bw_method=None, ind=None, column_num=None, stacking_id=None, **kwds):
    """Evaluate the KDE of ``y`` at ``ind`` and draw the result on ``ax``.

    ``column_num`` and ``stacking_id`` are accepted but not used by this
    implementation. ``style`` and any remaining keywords are forwarded to
    ``PandasMPLPlot._plot``, whose return value is passed back unchanged.
    """
    density = KdePlotBase.compute_kde(y, bw_method=bw_method, ind=ind)
    # The evaluation points go on the x-axis and their densities on the y-axis.
    return PandasMPLPlot._plot(ax, ind, density, style=style, **kwds)
def _get_ind(self, y):
    """Return the KDE evaluation points for ``y``.

    Delegates to ``KdePlotBase.get_ind`` and passes along this plot's
    configured ``ind`` setting.
    """
    configured_ind = self.ind
    return KdePlotBase.get_ind(y, configured_ind)
def _compute_plot_data(self):
    """Replace ``self.data`` with the result of ``KdePlotBase.prepare_kde_data``."""
    prepared = KdePlotBase.prepare_kde_data(self.data)
    self.data = prepared