コード例 #1
0
ファイル: stats.py プロジェクト: vipul1409/handyspark
def entropy(sdf, colnames):
    """Compute the base-2 entropy of one or more columns.

    For every requested column, the result of ``distribution(sdf, colname)``
    is reduced with ``sum(-log2(__probability) * __probability)``.

    Parameters
    ----------
    sdf
        DataFrame-like object handed to ``check_columns`` and
        ``distribution``.
    colnames
        A single column name or a list of column names.

    Returns
    -------
    pandas.Series
        One entropy value per column, indexed by column name.
    """
    colnames = ensure_list(colnames)
    check_columns(sdf, colnames)

    def column_entropy(name):
        # NOTE(review): relies on `distribution` exposing a
        # `__probability` column -- confirm against its definition.
        probabilities = distribution(sdf, name)
        summed = probabilities.select(
            F.sum(F.expr('-log2(__probability)*__probability')))
        # The aggregation is fetched as a single row with a single cell.
        first_row = summed.take(1)[0]
        return first_row[0]

    per_column = [column_entropy(name) for name in colnames]
    return pd.Series(per_column, index=colnames)
コード例 #2
0
def scatterplot(sdf, col1, col2, n=30, ax=None):
    """Draw a bucketized scatterplot of ``col1`` against ``col2``.

    Both columns are bucketized by ``model``; every (bucket1, bucket2)
    combination is counted, the counts are divided by ``total`` and each
    combination is drawn as a single point, sized by that proportion and
    placed at the midpoint of its buckets.

    Parameters
    ----------
    sdf
        DataFrame-like object providing ``_get_strata()`` and ``_handy``.
    col1, col2
        Names of the columns plotted on the x and y axes.
    n : int
        Number of buckets per column (passed to ``strat_scatterplot`` and
        used to number the bucket indexes).
    ax
        Axis (or list of axes) to draw on. A new one is created when it is
        ``None``; it is replaced by the strata axes when ``_get_strata()``
        returns precomputed data.

    Returns
    -------
    The single axis when only one was used, otherwise the list of axes.
    """
    strat_ax, data = sdf._get_strata()
    if data is None:
        data = strat_scatterplot(sdf, col1, col2, n)
    else:
        # Precomputed strata data comes with its own axes.
        ax = strat_ax
    # NOTE(review): `model` is used as a fitted pipeline whose stages expose
    # `getSplits()`; `total` is the divisor that turns counts into
    # proportions -- confirm against `strat_scatterplot`.
    model, total = data

    if ax is None:
        _, ax = plt.subplots(1, 1)

    axes = ensure_list(ax)
    clauses = sdf._handy._strata_raw_clauses
    if not len(clauses):
        # No stratification: plot everything once, unfiltered.
        clauses = [None]

    bucket_name1 = '__{}_bucket'.format(col1)
    bucket_name2 = '__{}_bucket'.format(col2)
    strata = sdf._handy.strata_colnames
    colnames = strata + [bucket_name1, bucket_name2]
    result = (model.transform(sdf)
              .select(colnames)
              .groupby(colnames)
              .agg(F.count('*').alias('count'))
              .toPandas()
              .sort_values(by=colnames))

    # Represent every bucket by the midpoint of its two boundaries.
    splits = [bucket.getSplits() for bucket in model.stages]
    splits = [
        list(map(np.mean, zip(split[1:], split[:-1]))) for split in splits
    ]
    splits1 = pd.DataFrame({bucket_name1: np.arange(0, n), col1: splits[0]})
    splits2 = pd.DataFrame({bucket_name2: np.arange(0, n), col2: splits[1]})

    df_counts = result.merge(splits1).merge(splits2)[
        strata + [col1, col2, 'count']].rename(columns={'count': 'Proportion'})

    # Replace the column instead of writing through `.loc[:, ...]`: the
    # counts are integers and the proportions are floats, and an in-place
    # write of an incompatible dtype is deprecated in recent pandas.
    df_counts['Proportion'] = df_counts['Proportion'].apply(
        lambda p: round(p / total, 4))

    for axis, clause in zip(axes, clauses):
        subset = df_counts
        if clause is not None:
            subset = subset.query(clause)
        sns.scatterplot(data=subset,
                        x=col1,
                        y=col2,
                        size='Proportion',
                        ax=axis,
                        legend=False)

    if len(axes) == 1:
        axes = axes[0]

    return axes
コード例 #3
0
ファイル: stats.py プロジェクト: vipul1409/handyspark
def StatisticalSummaryValues(sdf, colnames):
    """Builds a Java StatisticalSummaryValues object for each column

    The statistics come from ``describe()`` on the selected columns, whose
    rows are read positionally as count, mean, stddev, min and max.

    Parameters
    ----------
    sdf
        DataFrame-like object providing ``_sc._jvm`` and ``notHandy()``.
    colnames
        A single column name or a list of column names.

    Returns
    -------
    dict
        Maps each column name to its Java ``StatisticalSummaryValues``
        object.
    """
    colnames = ensure_list(colnames)
    check_columns(sdf, colnames)

    jvm = sdf._sc._jvm
    # The class lookup does not depend on the column, so do it once.
    java_class = jvm.org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues
    summ = sdf.notHandy().select(colnames).describe().toPandas().set_index(
        'summary')
    ssvs = {}
    for colname in colnames:
        values = [float(value) for value in summ[colname].values]
        count, mean, stddev, minimum, maximum = values[:5]
        # Constructor order is (mean, variance, n, max, min, sum).
        # describe() reports the standard deviation, so the variance is
        # its square (not its square root).
        ssvs[colname] = java_class(mean, stddev ** 2, int(count), maximum,
                                   minimum, count * mean)
    return ssvs
コード例 #4
0
def strat_histogram(sdf, colname, bins=10, categorical=False):
    """Compute the data behind a (possibly stratified) histogram.

    Parameters
    ----------
    sdf
        DataFrame-like object providing ``cols``, ``_handy`` and ``agg``.
    colname
        Name of the column to summarize.
    bins : int
        Number of equal-width buckets (numerical case only).
    categorical : bool
        When true, count the distinct values instead of bucketizing.

    Returns
    -------
    tuple
        ``(start_values, result)``. In the categorical case these are the
        index entries of the value counts and the value counts themselves.
        Otherwise they are the ``bins + 1`` bucket boundaries and a
        DataFrame with the strata columns, the bucket index and ``count``.
    """
    if categorical:
        counts = sdf.cols[colname]._value_counts(dropna=False, raw=True)

        # An index exposing `levels` is multi-level: rebuild the full
        # product of the outer levels and the observed values, so that
        # combinations without any row show up with a count of zero.
        if hasattr(counts.index, 'levels'):
            observed_values = counts.reset_index()[colname].unique().tolist()
            full_index = pd.MultiIndex.from_product(
                counts.index.levels[:-1] + [observed_values],
                names=counts.index.names)
            padded = pd.DataFrame(index=full_index).join(counts.to_frame(),
                                                         how='left')
            counts = padded.fillna(0)[counts.name].astype(counts.dtype)

        return counts.index.tolist(), counts

    bucket_name = '__{}_bucket'.format(colname)
    strata = sdf._handy.strata_colnames
    group_cols = strata + ensure_list(bucket_name)

    # Equal-width boundaries spanning the column's minimum and maximum.
    lowest, highest = sdf.agg(F.min(colname),
                              F.max(colname)).rdd.map(tuple).collect()[0]
    edges = np.linspace(lowest, highest, bins + 1)

    bucketizer = Bucketizer(splits=edges,
                            inputCol=colname,
                            outputCol=bucket_name,
                            handleInvalid="skip")
    bucketed = bucketizer.transform(sdf).select(group_cols)
    counted = bucketed.groupby(group_cols).agg(F.count('*').alias('count'))
    observed = counted.toPandas().sort_values(by=group_cols)

    # One row per bucket index, paired with the bucket's lower boundary.
    grid = pd.DataFrame({
        bucket_name: np.arange(0, bins),
        'bucket': edges[:-1]
    })
    if len(strata):
        # Cross join (through a constant key) with the observed strata
        # combinations, so every stratum gets every bucket.
        strata_values = observed[strata].drop_duplicates().assign(key=1)
        grid = grid.assign(key=1).merge(strata_values,
                                        on='key').drop(columns=['key'])

    # Buckets without any row have no match and get a count of zero.
    merged = grid.merge(observed, how='left',
                        on=strata + [bucket_name]).fillna(0)
    return edges, merged[strata + [bucket_name, 'count']]
コード例 #5
0
def histogram(sdf, colname, bins=10, categorical=False, ax=None):
    """Plot a histogram (or bar chart) of a column, one axis per stratum.

    Parameters
    ----------
    sdf
        DataFrame-like object providing ``_get_strata()`` and ``_handy``.
    colname
        Name of the column to plot; also used as the plot title.
    bins : int
        Number of buckets for numerical columns; for categorical columns,
        the maximum number of bars drawn.
    categorical : bool
        When true, draw a bar chart of value counts instead of a histogram.
    ax
        Axis (or list of axes) to draw on. A new one is created when it is
        ``None``; it is replaced by the strata axes when ``_get_strata()``
        returns precomputed data.

    Returns
    -------
    The single axis when only one was used, otherwise the list of axes.
    """
    strat_ax, data = sdf._get_strata()
    if data is not None:
        # Precomputed strata data comes with its own axes.
        ax = strat_ax
    else:
        data = strat_histogram(sdf, colname, bins, categorical)
    start_values, counts = data

    if ax is None:
        _, ax = plt.subplots(1, 1)

    axes = ensure_list(ax)
    clauses = sdf._handy._strata_raw_clauses
    if len(clauses) == 0:
        # No stratification: plot everything once, unfiltered.
        clauses = [None]

    for target_ax, clause in zip(axes, clauses):
        if categorical:
            frame = counts.sort_index().to_frame()
            if clause is not None:
                # Keep only this stratum, then move the strata levels out
                # of the index and drop them.
                frame = frame.query(clause)
                frame = frame.reset_index(sdf._handy.strata_colnames)
                frame = frame.drop(columns=sdf._handy.strata_colnames)
            frame.iloc[:bins].plot(kind='bar',
                                   color='C0',
                                   legend=False,
                                   rot=0,
                                   ax=target_ax,
                                   title=colname)
            continue

        # One sample per bucket at its lower boundary, weighted by the
        # bucket's count, reproduces the precomputed histogram.
        selected = counts if clause is None else counts.query(clause)
        target_ax.hist(start_values[:-1],
                       bins=start_values,
                       weights=selected['count'].values)
        target_ax.set_title(colname)

    return axes[0] if len(axes) == 1 else axes