Example #1
import pyarrow as pa


def cast_for_truediv(arrow_array: pa.ChunkedArray,
                     pa_object: pa.Array | pa.Scalar) -> pa.ChunkedArray:
    # Ensure int / int -> float, mirroring Python/NumPy behavior,
    # because pc.divide_checked(int, int) -> int
    if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(
            pa_object.type):
        return arrow_array.cast(pa.float64())
    return arrow_array
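A quick check of the behavior (hypothetical values; assumes `cast_for_truediv` from this example is in scope):

import pyarrow as pa

ints = pa.chunked_array([pa.array([10, 3, None])])
floats = cast_for_truediv(ints, pa.scalar(4))
assert floats.type == pa.float64()
pa.compute.divide(floats, pa.scalar(4))  # [2.5, 0.75, None]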
Example #2
import pyarrow as pa


def _timestamp_is_rounded(column: pa.ChunkedArray,
                          granularity: DateGranularity) -> bool:
    factor = {
        DateGranularity.SECOND: 1_000_000_000,
        DateGranularity.MINUTE: 1_000_000_000 * 60,
        DateGranularity.HOUR: 1_000_000_000 * 60 * 60,
    }[granularity]
    ints = column.cast(pa.int64())
    return pa.compute.all(
        pa.compute.equal(
            ints, pa.compute.multiply(pa.compute.divide(ints, factor),
                                      factor))).as_py()
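`DateGranularity` comes from the surrounding module and is not shown on this page. Assuming it is a plain Enum with SECOND/MINUTE/HOUR members, a call might look like this (hypothetical values):

import datetime

import pyarrow as pa

ts = pa.chunked_array(
    [pa.array([datetime.datetime(2021, 5, 1, 12, 30)], pa.timestamp("ns"))])
_timestamp_is_rounded(ts, DateGranularity.MINUTE)  # True: 12:30:00 is on a minute
_timestamp_is_rounded(ts, DateGranularity.HOUR)  # False: 12:30 is not on an hour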
Example #3
import numpy as np
import pyarrow as pa


def recode_or_decode_dictionary(
        chunked_array: pa.ChunkedArray) -> pa.ChunkedArray:
    """Remove unused/duplicate dictionary values from `chunked_array`, or cast it to pa.utf8().

    Workbench disallows unused/duplicate values. Call this function after
    filtering or modifying dictionary values: it returns a valid Workbench
    column given a valid Arrow column.

    Convert to utf8() if dictionary encoding is "bad". ("Bad" currently means
    "each value is only used once", but the meaning may change between minor
    versions.)

    Return `chunked_array` if it is already Workbench-valid and dictionary
    encoding is not "bad".
    """
    if chunked_array.num_chunks == 0:
        return pa.chunked_array([], pa.utf8())

    # if chunked_array.num_chunks != 1:
    #     chunked_array = chunked_array.unify_dictionaries()
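    # (Note: the code below reads chunks[0].dictionary, so it assumes every
    # chunk shares the same dictionary.)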

    if len(chunked_array) - chunked_array.null_count <= len(
            chunked_array.chunks[0].dictionary):
        return chunked_array.cast(pa.utf8())

    dictionary = chunked_array.chunks[0].dictionary

    used = np.zeros(len(dictionary), dtype=bool)
    for chunk in chunked_array.chunks:
        used[pa.compute.filter(chunk.indices, pa.compute.is_valid(
            chunk.indices)).to_numpy()] = True

    if not np.all(used):
        # Nix unused values; then scan for dups
        mapping = dictionary.filter(pa.array(used,
                                             pa.bool_())).dictionary_encode()
        need_recode = True
    else:
        # Scan for dups
        mapping = dictionary.dictionary_encode()
        need_recode = len(mapping.dictionary) < len(dictionary)

    if need_recode:
        chunks = [_recode(chunk, mapping) for chunk in chunked_array.chunks]
        return pa.chunked_array(chunks)

    return chunked_array
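`_recode` is a helper defined elsewhere in the same module and not shown on this page. The decode path never reaches it, so a minimal demonstration (hypothetical values; assumes `recode_or_decode_dictionary` above is in scope) is:

import pyarrow as pa

col = pa.chunked_array([pa.array(["a", "b", "c"]).dictionary_encode()])
# Each value is used exactly once -- a "bad" encoding -- so we get utf8 back
decoded = recode_or_decode_dictionary(col)
assert decoded.type == pa.utf8()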
Example #4
import pyarrow as pa


def _startof(column: pa.ChunkedArray, unit: str) -> StartofColumnResult:
    factor = pa.scalar(_NS_PER_UNIT[unit], pa.int64())
    timestamp_ints = column.cast(pa.int64())

    # Arrow's integer division truncates toward zero, which rounds negative
    # numbers _up_. Subtract before truncating.
    #
    # In decimal, if we're truncating to the nearest 10:
    #
    # 0 => 0
    # -1 => -10
    # -9 => -10
    # -10 => -10
    # -11 => -20
    #
    # ... rule is: subtract 9 from all negative numbers, then truncate.

    negative = pa.compute.less(timestamp_ints, pa.scalar(0, pa.int64()))
    # "offset": -9 for negatives, 0 for others
    offset = pa.compute.multiply(
        negative.cast(pa.int64()),
        pa.scalar(-1 * _NS_PER_UNIT[unit] + 1, pa.int64()))
    # to_truncate may overflow; in that case, to_truncate > timestamp_ints
    to_truncate = pa.compute.add(timestamp_ints, offset)
    truncated = pa.compute.multiply(pa.compute.divide(to_truncate, factor),
                                    factor)

    # or_kleene(x, null) is True where x is True and null elsewhere -- a mask
    # like [True, None, True, True, None], True meaning "no overflow"
    safe_or_null = pa.compute.or_kleene(
        pa.compute.less_equal(to_truncate, timestamp_ints),
        pa.scalar(None, pa.bool_()))

    truncated_or_null = truncated.filter(safe_or_null,
                                         null_selection_behavior="emit_null")

    return StartofColumnResult(
        column=truncated_or_null.cast(pa.timestamp("ns")),
        truncated=(truncated_or_null.null_count > column.null_count),
    )
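`_NS_PER_UNIT` and `StartofColumnResult` also live in the surrounding module and are not shown on this page. A minimal sketch of compatible definitions and a call, with assumed (not authoritative) names and values:

from typing import NamedTuple

import pyarrow as pa

_NS_PER_UNIT = {  # assumed values, mirroring the factors in Example #2
    "second": 1_000_000_000,
    "minute": 1_000_000_000 * 60,
    "hour": 1_000_000_000 * 60 * 60,
}


class StartofColumnResult(NamedTuple):  # assumed shape, inferred from the return site
    column: pa.ChunkedArray
    truncated: bool


ts = pa.chunked_array([pa.array([1_500_000_000], pa.timestamp("ns"))])
_startof(ts, "second")  # column=[1970-01-01 00:00:01], truncated=False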