예제 #1
0
    def __init__(self,
                 marker_column: str,
                 marker_start: Any,
                 marker_end: Any = NONEVALUE,
                 marker_start_use_first: bool = False,
                 marker_end_use_first: bool = True,
                 orderby_columns: TYPE_COLUMNS = None,
                 groupby_columns: TYPE_COLUMNS = None,
                 ascending: TYPE_ASCENDING = None,
                 result_type: str = "enumerated",
                 target_column_name: str = "iids"):
        """Store and validate the marker, ordering and grouping settings.

        Parameters
        ----------
        marker_column: str
            Name of the column holding the marker values.
        marker_start: Any
            Marker value denoting a start.
        marker_end: Any, optional
            Marker value denoting an end. If omitted (left at `NONEVALUE`)
            or equal to `marker_start`, start and end markers are considered
            identical.
        marker_start_use_first: bool, optional
            Stored as-is. Defaults to False.
        marker_end_use_first: bool, optional
            Stored as-is. Defaults to True.
        orderby_columns: TYPE_COLUMNS, optional
            Column(s) defining the sort order. Normalized to an iterable via
            `sanitizer.ensure_iterable`.
        groupby_columns: TYPE_COLUMNS, optional
            Column(s) to group by. Normalized to an iterable via
            `sanitizer.ensure_iterable`.
        ascending: TYPE_ASCENDING, optional
            One bool per entry of `orderby_columns`. If not given while
            `orderby_columns` is non-empty, every column defaults to
            ascending (True).
        result_type: str, optional
            One of "raw", "valid" or "enumerated" (default).
        target_column_name: str, optional
            Name of the result column. Defaults to "iids".

        Raises
        ------
        ValueError
            If `result_type` is not an allowed value, if `orderby_columns`
            and `ascending` differ in length, or if `ascending` contains
            anything other than `True`/`False`.

        """

        self.marker_column = marker_column
        self.marker_start = marker_start
        self.marker_end = marker_end
        self.marker_start_use_first = marker_start_use_first
        self.marker_end_use_first = marker_end_use_first
        self.orderby_columns = sanitizer.ensure_iterable(orderby_columns)
        self.groupby_columns = sanitizer.ensure_iterable(groupby_columns)
        self.ascending = sanitizer.ensure_iterable(ascending)
        self.result_type = result_type
        self.target_column_name = target_column_name

        # check correct result type
        valid_result_types = {"raw", "valid", "enumerated"}
        if result_type not in valid_result_types:
            raise ValueError("Parameter `result_type` is invalid with: {}. "
                             "Allowed arguments are: {}".format(
                                 result_type, valid_result_types))

        # check for identical start and end values; `NONEVALUE` is the
        # "not given" default, so compare by identity. `==` could misbehave
        # for array-like marker values with element-wise equality.
        self._identical_start_end_markers = ((marker_end is NONEVALUE)
                                             or (marker_start == marker_end))

        # sanity checks for sort order
        if self.ascending:

            # check for equal number of items of order and sort columns
            if len(self.orderby_columns) != len(self.ascending):
                raise ValueError('`orderby_columns` and `ascending` must have '
                                 'equal number of items.')

            # check for correct sorting keywords
            if not all(isinstance(x, bool) for x in self.ascending):
                raise ValueError('Only `True` and `False` are '
                                 'allowed arguments for `ascending`')

        # set default sort order if None is given
        elif self.orderby_columns:
            self.ascending = [True] * len(self.orderby_columns)
예제 #2
0
    def disable(self, stages: List[TYPE_IDENTIFIER]) -> None:
        """Disable pipeline caching for given stages. Stage can be identified
        via index, identifier or stage itself.

        If pipeline was already transformed, disables caching on existing
        dataframe representations. However, `transform` has to be called again
        for the execution plan of the pipeline's result dataframe to respect
        caching changes.

        Stages are processed in order. If one of them is not cached, the
        stages preceding it have already been disabled when the error is
        raised.

        Parameters
        ----------
        stages: iterable
            A single stage or an iterable of int, str or Transformer.

        Raises
        ------
        ValueError
            If a given stage is not currently present in the cache.

        """

        stages = ensure_iterable(stages)

        for stage in stages:
            idx = self.pipeline._loc.get_index_location(stage)

            try:
                self._store.remove(idx)
            except KeyError as err:
                # trailing space keeps the two literals from fusing words
                raise ValueError("'{}' does not exist in cache and hence "
                                 "cannot be disabled.".format(stage)) from err

            # only an already transformed pipeline has dataframes to release
            if self.pipeline._transformer:
                self.pipeline(idx).unpersist(blocking=True)
예제 #3
0
def test_ensure_iterable_custom_class(seq_type):
    """A non-iterable custom object is wrapped into a one-element sequence."""

    class Dummy:
        pass

    instance = Dummy()
    expected = seq_type([instance])

    result = ensure_iterable(instance, seq_type)
    assert result == expected
예제 #4
0
    def profile(self, *dfs: pd.DataFrame, **kwargs):
        """Record memory usage for the input dataframes `dfs`, which are
        handed on to `fit_transform`.

        Stores the memory usage of the inputs and of the outputs of
        `fit_transform`, then delegates to the parent class's `profile`
        (forwarding `kwargs`).

        Returns
        -------
        self

        """

        # memory consumed by the inputs
        self._usage_input = self._memory_usage_dfs(*dfs)

        # memory consumed by the outputs; a single result is wrapped into an
        # iterable so it can be unpacked uniformly
        outputs = sanitizer.ensure_iterable(self._wrangler.fit_transform(*dfs))
        self._usage_output = self._memory_usage_dfs(*outputs)

        # usage during fit_transform, as measured by the parent implementation
        super().profile(*dfs, **kwargs)

        return self
예제 #5
0
def prepare_orderby(orderby_columns: TYPE_PYSPARK_COLUMNS,
                    ascending: TYPE_ASCENDING = True,
                    reverse: bool = False) -> List[Column]:
    """Convenient function to return orderby columns in correct
    ascending/descending order.

    Parameters
    ----------
    orderby_columns: TYPE_PYSPARK_COLUMNS
        Columns to explicitly apply an order to.
    ascending: TYPE_ASCENDING, optional
        Define order of columns via bools. True and False refer to ascending
        and descending, respectively. A single bool is broadcast to all
        columns. `None` is treated like `True` (all ascending). Any iterable
        of bools with one entry per column is accepted.
    reverse: bool, optional
        Reverse the given order. By default, not activated.

    Returns
    -------
    ordered: list
        List of order columns.

    Raises
    ------
    ValueError
        If `ascending` is a sequence whose length differs from the number of
        `orderby_columns`.

    """

    # ensure columns
    orderby_columns = [ensure_column(column)
                       for column in ensure_iterable(orderby_columns)]

    # no explicit order given: fall back to the default (ascending)
    if ascending is None:
        ascending = True

    # check if only True/False is given broadcast
    if isinstance(ascending, bool):
        ascending = [ascending] * len(orderby_columns)

    # ensure equal lengths, otherwise raise
    else:
        # materialize so `len` also works for unsized iterables (generators)
        ascending = list(ascending)
        if len(orderby_columns) != len(ascending):
            raise ValueError('`orderby_columns` and `ascending` must have '
                             'equal number of items.')

    def boolify(sort_ascending: Optional[bool]) -> bool:
        # `!=` on two bools acts as XOR: `reverse` flips the requested order
        return bool(sort_ascending) != reverse

    return [column.asc() if boolify(sort_ascending) else column.desc()
            for column, sort_ascending in zip(orderby_columns, ascending)]
예제 #6
0
def validate_columns(df: pd.DataFrame, columns: TYPE_COLUMNS):
    """Ensure every entry of `columns` is a column of `df`.

    Parameters
    ----------
    df: pd.DataFrame
        Dataframe whose columns are checked.
    columns: iterable[str]
        Column name(s) expected to exist. A single name is accepted as well.

    Raises
    ------
    ValueError
        For the first name (in the given order) missing from `df.columns`.

    """

    for name in ensure_iterable(columns):
        if name in df.columns:
            continue

        raise ValueError('Column with name `{}` does not exist. '
                         'Please check parameter settings.'
                         .format(name))
예제 #7
0
    def enable(self, stages: List[TYPE_IDENTIFIER]) -> None:
        """Turn on pipeline caching for the given stages, each identified by
        index, identifier or the stage object itself.

        When the pipeline has already been transformed, caching is also
        applied to the existing dataframe representations. `transform` must
        still be called again for the execution plan of the pipeline's result
        dataframe to reflect the changed caching.

        Parameters
        ----------
        stages: iterable
            A single stage or an iterable of int, str or Transformer.

        """

        for stage in ensure_iterable(stages):
            position = self.pipeline._loc.get_index_location(stage)
            self._store.add(position)

            # already transformed: cache the existing dataframe as well
            if self.pipeline._transformer:
                self.pipeline(position).cache()
예제 #8
0
def test_ensure_iterable_none(seq_type):
    """`None` becomes an empty sequence unless `retain_none` is set."""
    empty = seq_type()
    assert ensure_iterable(None, seq_type) == empty

    kept = ensure_iterable(None, seq_type, retain_none=True)
    assert kept is None
예제 #9
0
def test_ensure_iterable_strings(seq_type):
    """A list of strings keeps its elements instead of being wrapped."""
    expected = seq_type(["test1", "test2"])
    result = ensure_iterable(["test1", "test2"], seq_type)

    assert result == expected
예제 #10
0
def test_ensure_iterable_string(seq_type):
    """A single string is wrapped whole rather than split into characters."""
    text = "test_string"
    assert ensure_iterable(text, seq_type) == seq_type([text])
예제 #11
0
def test_ensure_iterable_number(seq_type):
    """A scalar number is wrapped into a one-element sequence."""
    number = 3
    result = ensure_iterable(number, seq_type)

    assert result == seq_type([number])