Exemple #1
0
def _assert_series_inner(
    left: Series,
    right: Series,
    check_dtype: bool,
    check_exact: bool,
    atol: float,
    rtol: float,
    obj: str,
) -> None:
    """
    Check two Series for matching dtype (optional) and matching values.

    Values are compared exactly when ``check_exact`` is set, and also when
    the left dtype cannot take part in a tolerance check: either
    ``dtype_to_py_type`` raises ``NotImplementedError`` for it, the mapped
    Python type has no ``__sub__``, or the dtype equals ``Boolean``.
    Otherwise an element passes when ``|left - right|`` does not exceed
    ``atol + rtol * |right|``.

    Parameters
    ----------
    left, right
        The Series to compare.
    check_dtype
        When true, differing dtypes are reported as "Dtype mismatch".
    check_exact
        Request exact comparison instead of the tolerance comparison.
    atol, rtol
        Absolute and relative tolerance used by the tolerance comparison.
    obj
        Forwarded unchanged as the first argument of ``raise_assert_detail``.

    Every failure is reported through ``raise_assert_detail``
    (presumably raising an assertion error; confirm against that helper).
    """
    # The tolerance path needs ``left - right``; assume it is unavailable
    # until the mapped Python type proves otherwise.
    supports_subtraction = False
    try:
        supports_subtraction = hasattr(dtype_to_py_type(left.dtype), "__sub__")
    except NotImplementedError:
        # No Python type is known for this dtype: keep the exact fallback.
        pass

    compare_exactly = (
        check_exact or not supports_subtraction or left.dtype == Boolean
    )

    if check_dtype and left.dtype != right.dtype:
        raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)

    if compare_exactly:
        reason = "Exact value mismatch"
        mismatch_total = (left != right).sum()
    else:
        reason = "Value mismatch"
        abs_diff = (left - right).abs()
        allowed = atol + rtol * right.abs()
        mismatch_total = (abs_diff > allowed).sum()

    if mismatch_total != 0:
        raise_assert_detail(obj, reason, left=list(left), right=list(right))
Exemple #2
0
def _assert_series_inner(
    left: pli.Series,
    right: pli.Series,
    check_dtype: bool,
    check_exact: bool,
    nans_compare_equal: bool,
    atol: float,
    rtol: float,
    obj: str,
) -> None:
    """
    Check two Series for matching dtype (optional) and matching values.

    A mask of unequal positions is built first. When ``nans_compare_equal``
    is set and the left dtype is ``Float32`` or ``Float64``, positions where
    both sides are NaN are removed from that mask. If any unequal position
    remains, it is either reported immediately (exact mode) or re-checked
    with tolerance: an element passes when ``|left - right|`` does not
    exceed ``atol + rtol * |right|``.

    Exact mode is used when ``check_exact`` is set, and also when
    ``dtype_to_py_type`` raises ``NotImplementedError`` for the left dtype,
    when the mapped Python type has no ``__sub__``, or when the left dtype
    equals ``Boolean``.

    Parameters
    ----------
    left, right
        The Series to compare.
    check_dtype
        When true, differing dtypes are reported as "Dtype mismatch".
    check_exact
        Request exact comparison instead of the tolerance comparison.
    nans_compare_equal
        Treat a NaN on both sides of a float Series as equal.
    atol, rtol
        Absolute and relative tolerance used by the tolerance comparison.
    obj
        Forwarded unchanged as the first argument of ``raise_assert_detail``.

    Every failure is reported through ``raise_assert_detail``
    (presumably raising an assertion error; confirm against that helper).
    In the tolerance case only the elements that compared unequal are
    included in the report.
    """
    # The tolerance path needs ``left - right``; assume it is unavailable
    # until the mapped Python type proves otherwise.
    supports_subtraction = False
    try:
        supports_subtraction = hasattr(dtype_to_py_type(left.dtype), "__sub__")
    except NotImplementedError:
        # No Python type is known for this dtype: keep the exact fallback.
        pass

    compare_exactly = (
        check_exact or not supports_subtraction or left.dtype == Boolean
    )

    if check_dtype and left.dtype != right.dtype:
        raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)

    # Boolean mask marking the positions whose values differ.
    differs = left != right
    if differs.any() and nans_compare_equal and left.dtype in (Float32,
                                                               Float64):
        # NaN compares unequal to itself, so drop positions where both
        # sides are NaN; nulls in that both-NaN mask count as False.
        both_nan = (left.is_nan() & right.is_nan()).fill_null(pli.lit(False))
        differs = differs & ~both_nan

    if not differs.any():
        return

    if compare_exactly:
        reason = "Exact value mismatch"
        shown_left, shown_right = left, right
    else:
        # Only the known-unequal elements need the tolerance check.
        shown_left = left.filter(differs)
        shown_right = right.filter(differs)
        abs_diff = (shown_left - shown_right).abs()
        allowed = atol + rtol * shown_right.abs()
        out_of_tolerance = (abs_diff > allowed).sum() != 0
        reason = "Value mismatch" if out_of_tolerance else None

    if reason is not None:
        raise_assert_detail(obj,
                            reason,
                            left=list(shown_left),
                            right=list(shown_right))