def plot_kde(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
    """Build a plotly express line figure of per-column KDE values.

    Parameters
    ----------
    data : ps.DataFrame or ps.Series
        Input handed to ``KdePlotBase.prepare_kde_data``.
    **kwargs
        ``ind`` and ``bw_method`` are removed and passed to the
        ``KdePlotBase`` helpers; everything else is forwarded unchanged to
        ``plotly.express.line``. When ``data`` is a DataFrame and no
        ``color`` is given, ``color`` defaults to ``"names"`` so each
        column label is drawn as its own trace.

    Returns
    -------
    The figure returned by ``plotly.express.line``, with its x-axis title
    cleared.
    """
    from plotly import express
    import pyspark.pandas as ps

    # For frames, colour by the "names" column unless the caller chose one.
    if isinstance(data, ps.DataFrame):
        kwargs.setdefault("color", "names")

    prepared = KdePlotBase.prepare_kde_data(data)
    internal = prepared._internal
    spark_frame = internal.spark_frame

    # Both options are consumed here so they never reach express.line.
    ind = KdePlotBase.get_ind(
        spark_frame.select(*internal.data_spark_columns),
        kwargs.pop("ind", None),
    )
    bw_method = kwargs.pop("bw_method", None)

    # One long-format frame per column label, all sharing the same ``ind``.
    per_label_frames = [
        pd.DataFrame(
            {
                "Density": KdePlotBase.compute_kde(
                    spark_frame.select(prepared._internal.spark_column_for(label)),
                    ind=ind,
                    bw_method=bw_method,
                ),
                "names": name_like_string(label),
                "index": ind,
            }
        )
        for label in internal.column_labels
    ]
    combined = pd.concat(per_label_frames)

    figure = express.line(combined, x="index", y="Density", **kwargs)
    # Drop the x-axis title (it would otherwise be the "index" column name).
    figure["layout"]["xaxis"]["title"] = None
    return figure
def _get_ind(self, y):
    """Return the result of ``KdePlotBase.get_ind`` for ``y`` and ``self.ind``.

    Thin instance-level wrapper: the shared ``KdePlotBase.get_ind`` helper
    is called with ``y`` and this object's ``ind`` attribute, and its
    result is returned unchanged.
    """
    resolve_ind = KdePlotBase.get_ind
    return resolve_ind(y, self.ind)