コード例 #1
0
    def __init__(
        self,
        on: Optional[KeyFn] = None,
        ddof: int = 1,
        ignore_nulls: bool = True,
    ):
        """Set up a standard-deviation aggregation.

        The running state is a three-element list ``[M2, mean, count]``,
        where ``M2`` is the sum of squared differences from the current mean.

        Args:
            on: Key selecting the values to aggregate; forwarded to
                ``_set_key_fn`` and to every block-accessor call.
            ddof: Delta degrees of freedom; the final divisor is
                ``count - ddof``.
            ignore_nulls: Forwarded to the null-wrapping helpers and to the
                block accessor's ``sum`` and
                ``sum_of_squared_diffs_from_mean`` calls.
        """
        self._set_key_fn(on)

        def merge(a: List[float], b: List[float]):
            # Merges two [M2, mean, count] accumulations into one.
            # See
            # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
            M2_a, mean_a, count_a = a
            M2_b, mean_b, count_b = b
            delta = mean_b - mean_a
            count = count_a + count_b
            # NOTE: We use this mean calculation since it's more numerically
            # stable than mean_a + delta * count_b / count, which actually
            # deviates from Pandas in the ~15th decimal place and causes our
            # exact comparison tests to fail.
            mean = (mean_a * count_a + mean_b * count_b) / count
            # Update the sum of squared differences.
            M2 = M2_a + M2_b + (delta ** 2) * count_a * count_b / count
            return [M2, mean, count]

        null_merge = _null_wrap_merge(ignore_nulls, merge)

        def vectorized_std(block: Block[T]) -> AggType:
            # Reduce one whole block to a single [M2, mean, count] partial,
            # or None when the block contributes nothing usable.
            block_acc = BlockAccessor.for_block(block)
            count = block_acc.count(on)
            if count == 0 or count is None:
                # Empty or all null.
                return None
            sum_ = block_acc.sum(on, ignore_nulls)
            if sum_ is None:
                # ignore_nulls=False and at least one null.
                return None
            mean = sum_ / count
            M2 = block_acc.sum_of_squared_diffs_from_mean(on, ignore_nulls, mean)
            return [M2, mean, count]

        def finalize(a: List[float]):
            # Compute the final standard deviation from the accumulated
            # sum of squared differences from current mean and the count.
            M2, mean, count = a
            if count < 2:
                return 0.0
            dof = count - ddof
            if dof <= 0:
                # With ddof >= count (e.g. ddof=2 on a two-element group) the
                # divisor is zero or negative; the result is undefined, so
                # report NaN instead of raising ZeroDivisionError/ValueError.
                return float("nan")
            return math.sqrt(M2 / dof)

        super().__init__(
            init=_null_wrap_init(lambda k: [0, 0, 0]),
            merge=null_merge,
            accumulate_block=_null_wrap_accumulate_block(
                ignore_nulls,
                vectorized_std,
                null_merge,
            ),
            finalize=_null_wrap_finalize(finalize),
            name=(f"std({str(on)})"),
        )
コード例 #2
0
ファイル: aggregate.py プロジェクト: tchordia/ray
    def __init__(self, on: Optional[KeyFn] = None, ignore_nulls: bool = True):
        """Set up a mean aggregation.

        The running state is a two-element list ``[sum, count]``; the final
        value is ``sum / count``.

        Args:
            on: Key selecting the values to aggregate; forwarded to
                ``_set_key_fn`` and to the block-accessor calls.
            ignore_nulls: Forwarded to the null-wrapping helpers and to the
                block accessor's ``sum`` call.
        """
        self._set_key_fn(on)

        def add_partials(a1, a2):
            # Combine two [sum, count] partials position by position.
            combined_sum = a1[0] + a2[0]
            combined_count = a1[1] + a2[1]
            return [combined_sum, combined_count]

        null_merge = _null_wrap_merge(ignore_nulls, add_partials)

        def vectorized_mean(block: Block[T]) -> AggType:
            # Reduce one whole block to a single [sum, count] partial.
            accessor = BlockAccessor.for_block(block)
            n_values = accessor.count(on)
            if n_values is None or n_values == 0:
                # Empty or all null.
                return None
            total = accessor.sum(on, ignore_nulls)
            # A None total means ignore_nulls=False and at least one null.
            return None if total is None else [total, n_values]

        def sum_over_count(a):
            # Turn the final [sum, count] state into the mean.
            return a[0] / a[1]

        wrapped_accumulate = _null_wrap_accumulate_block(
            ignore_nulls,
            vectorized_mean,
            null_merge,
        )

        super().__init__(
            init=_null_wrap_init(lambda k: [0, 0]),
            merge=null_merge,
            accumulate_block=wrapped_accumulate,
            finalize=_null_wrap_finalize(sum_over_count),
            name=f"mean({str(on)})",
        )
コード例 #3
0
    def __init__(self, on: Optional[KeyFn] = None, ignore_nulls: bool = True):
        """Set up a max aggregation.

        The running state starts at negative infinity, partial results are
        combined with the built-in ``max``, and the final state is returned
        unchanged.

        Args:
            on: Key selecting the values to aggregate; forwarded to
                ``_set_key_fn`` and to the block accessor's ``max`` call.
            ignore_nulls: Forwarded to the null-wrapping helpers and to the
                block accessor's ``max`` call.
        """
        self._set_key_fn(on)

        def block_max(block):
            # Reduce one whole block to its maximum via the block accessor.
            accessor = BlockAccessor.for_block(block)
            return accessor.max(on, ignore_nulls)

        def passthrough(a):
            # The accumulated maximum is already the final answer.
            return a

        null_merge = _null_wrap_merge(ignore_nulls, max)
        wrapped_init = _null_wrap_init(lambda k: float("-inf"))
        wrapped_accumulate = _null_wrap_accumulate_block(
            ignore_nulls,
            block_max,
            null_merge,
        )

        super().__init__(
            init=wrapped_init,
            merge=null_merge,
            accumulate_block=wrapped_accumulate,
            finalize=_null_wrap_finalize(passthrough),
            name=f"max({str(on)})",
        )