Example #1
File: base.py    Project: soumya1984/spark
def transform_boolean_operand_to_numeric(operand: Any,
                                         *,
                                         spark_type: Optional[DataType] = None
                                         ) -> Any:
    """Transform boolean operand to numeric.

    If the `operand` is:
        - a boolean IndexOpsMixin, transform the `operand` to the `spark_type`.
        - a boolean literal, transform to the int value.
    Otherwise, return the operand as it is.
    """
    from pyspark.pandas.base import IndexOpsMixin

    if isinstance(operand, IndexOpsMixin) and isinstance(
            operand.spark.data_type, BooleanType):
        assert spark_type, "spark_type must be provided if the operand is a boolean IndexOpsMixin"
        assert isinstance(spark_type,
                          NumericType), "spark_type must be NumericType"
        dtype = spark_type_to_pandas_dtype(
            spark_type,
            use_extension_dtypes=operand._internal.data_fields[0].is_extension_dtype)
        return operand._with_new_scol(
            operand.spark.column.cast(spark_type),
            field=operand._internal.data_fields[0].copy(dtype=dtype,
                                                        spark_type=spark_type),
        )
    elif isinstance(operand, bool):
        return int(operand)
    else:
        return operand
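
Usage sketch (not part of the file above): a minimal illustration of the three branches, assuming an active Spark session. The import path for the helper is a guess based on the file name shown; adjust it to wherever base.py lives in your checkout.

import pyspark.pandas as ps
from pyspark.sql.types import LongType

# Assumed import path for the helper shown above.
from pyspark.pandas.data_type_ops.base import transform_boolean_operand_to_numeric

# A boolean literal is converted to its int value.
assert transform_boolean_operand_to_numeric(True) == 1

# A boolean Series (an IndexOpsMixin) is cast to the requested numeric Spark type.
psser = ps.Series([True, False, True])
numeric = transform_boolean_operand_to_numeric(psser, spark_type=LongType())
print(numeric.dtype)  # int64

# Anything else is returned unchanged.
assert transform_boolean_operand_to_numeric(3.5) == 3.5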
Example #2
File: utils.py    Project: aaalan321/spark
def combine_frames(this, *args, how="full", preserve_order_column=False):
    """
    This method combines `this` DataFrame with a `that` DataFrame or with
    Series that belong to a different DataFrame.

    It returns a DataFrame whose column names are prefixed with `this_` and
    `that_` to distinguish the columns coming from each DataFrame.

    It internally performs a join operation, which can be expensive in general,
    so this method raises an exception if the `compute.ops_on_diff_frames`
    option is False.
    """
    from pyspark.pandas.config import get_option
    from pyspark.pandas.frame import DataFrame
    from pyspark.pandas.internal import (
        InternalFrame,
        HIDDEN_COLUMNS,
        NATURAL_ORDER_COLUMN_NAME,
        SPARK_INDEX_NAME_FORMAT,
    )
    from pyspark.pandas.series import Series

    if all(isinstance(arg, Series) for arg in args):
        assert all(
            same_anchor(arg, args[0]) for arg in args
        ), "Currently only one other DataFrame (from the given Series) is supported"
        assert not same_anchor(
            this, args[0]), "We don't need to combine. All series are in this."
        that = args[0]._kdf[list(args)]
    elif len(args) == 1 and isinstance(args[0], DataFrame):
        assert isinstance(args[0], DataFrame)
        assert not same_anchor(
            this,
            args[0]), "We don't need to combine. `this` and `that` are the same."
        that = args[0]
    else:
        raise AssertionError("args should be a single DataFrame or "
                             "one or more Series")

    if get_option("compute.ops_on_diff_frames"):

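        # `resolve` renames every non-hidden column of the given frame with a
        # "__this_" / "__that_" prefix so the two Spark frames can be joined
        # without column-name collisions.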
        def resolve(internal, side):
            rename = lambda col: "__{}_{}".format(side, col)
            internal = internal.resolved_copy
            sdf = internal.spark_frame
            sdf = internal.spark_frame.select([
                scol_for(sdf, col).alias(rename(col))
                for col in sdf.columns if col not in HIDDEN_COLUMNS
            ] + list(HIDDEN_COLUMNS))
            return internal.copy(
                spark_frame=sdf,
                index_spark_columns=[
                    scol_for(sdf, rename(col))
                    for col in internal.index_spark_column_names
                ],
                data_spark_columns=[
                    scol_for(sdf, rename(col))
                    for col in internal.data_spark_column_names
                ],
            )

        this_internal = resolve(this._internal, "this")
        that_internal = resolve(that._internal, "that")

        this_index_map = list(
            zip(
                this_internal.index_spark_column_names,
                this_internal.index_names,
                this_internal.index_dtypes,
            ))
        that_index_map = list(
            zip(
                that_internal.index_spark_column_names,
                that_internal.index_names,
                that_internal.index_dtypes,
            ))
        assert len(this_index_map) == len(that_index_map)

        join_scols = []
        merged_index_scols = []

        # Note that the order of each element in index_map is guaranteed according to the index
        # level.
        this_and_that_index_map = list(zip(this_index_map, that_index_map))

        this_sdf = this_internal.spark_frame.alias("this")
        that_sdf = that_internal.spark_frame.alias("that")

        # If index levels with the same name are found, they are merged into one.
        index_column_names = []
        index_use_extension_dtypes = []
        for i, (
            (this_column, this_name, this_dtype),
            (that_column, that_name, that_dtype),
        ) in enumerate(this_and_that_index_map):
            if this_name == that_name:
                # We should merge the Spark columns into one
                # to mimic pandas' behavior.
                this_scol = scol_for(this_sdf, this_column)
                that_scol = scol_for(that_sdf, that_column)
                join_scol = this_scol == that_scol
                join_scols.append(join_scol)

                column_name = SPARK_INDEX_NAME_FORMAT(i)
                index_column_names.append(column_name)
                index_use_extension_dtypes.append(
                    any(
                        isinstance(dtype, extension_dtypes)
                        for dtype in [this_dtype, that_dtype]))
                merged_index_scols.append(
                    F.when(this_scol.isNotNull(),
                           this_scol).otherwise(that_scol).alias(column_name))
            else:
                raise ValueError(
                    "Index names must match exactly for now.")

        assert len(join_scols) > 0, "cannot join with no overlapping index names"

        joined_df = this_sdf.join(that_sdf, on=join_scols, how=how)

        if preserve_order_column:
            order_column = [scol_for(this_sdf, NATURAL_ORDER_COLUMN_NAME)]
        else:
            order_column = []

        joined_df = joined_df.select(merged_index_scols + [
            scol_for(this_sdf, this_internal.spark_column_name_for(label))
            for label in this_internal.column_labels
        ] + [
            scol_for(that_sdf, that_internal.spark_column_name_for(label))
            for label in that_internal.column_labels
        ] + order_column)

        index_spark_columns = [
            scol_for(joined_df, col) for col in index_column_names
        ]
        index_dtypes = [
            spark_type_to_pandas_dtype(
                field.dataType, use_extension_dtypes=use_extension_dtypes)
            for field, use_extension_dtypes in zip(
                joined_df.select(index_spark_columns).schema,
                index_use_extension_dtypes)
        ]

        index_columns = set(index_column_names)
        new_data_columns = [
            col for col in joined_df.columns
            if col not in index_columns and col != NATURAL_ORDER_COLUMN_NAME
        ]
        data_dtypes = this_internal.data_dtypes + that_internal.data_dtypes

        level = max(this_internal.column_labels_level,
                    that_internal.column_labels_level)

        def fill_label(label):
            if label is None:
                return ([""] * (level - 1)) + [None]
            else:
                return ([""] * (level - len(label))) + list(label)

        column_labels = [
            tuple(["this"] + fill_label(label))
            for label in this_internal.column_labels
        ] + [
            tuple(["that"] + fill_label(label))
            for label in that_internal.column_labels
        ]
        column_label_names = (
            [None] * (1 + level - this_internal.column_labels_level)
        ) + this_internal.column_label_names
        return DataFrame(
            InternalFrame(
                spark_frame=joined_df,
                index_spark_columns=index_spark_columns,
                index_names=this_internal.index_names,
                index_dtypes=index_dtypes,
                column_labels=column_labels,
                data_spark_columns=[
                    scol_for(joined_df, col) for col in new_data_columns
                ],
                data_dtypes=data_dtypes,
                column_label_names=column_label_names,
            ))
    else:
        raise ValueError(ERROR_MESSAGE_CANNOT_COMBINE)
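
Usage sketch (not part of the file above): combining two pandas-on-Spark DataFrames that share the same default index, assuming an active Spark session and that `combine_frames` is importable from the utils module shown.

import pyspark.pandas as ps
from pyspark.pandas.utils import combine_frames

# Joins across different DataFrames are only allowed when this option is on.
ps.set_option("compute.ops_on_diff_frames", True)

kdf1 = ps.DataFrame({"a": [1, 2, 3]})
kdf2 = ps.DataFrame({"b": [4, 5, 6]})

# Full outer join on the identically named default index; the resulting column
# labels are prefixed with "this" / "that".
combined = combine_frames(kdf1, kdf2)
print(combined.columns.tolist())  # [('this', 'a'), ('that', 'b')]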