def _check_and_cast_columns_with_other(
    source_col: ColumnBase,
    other: Union[ScalarLike, ColumnBase],
    inplace: bool,
) -> Tuple[ColumnBase, Union[ScalarLike, ColumnBase]]:
    """
    Cast ``source_col`` and ``other`` to mutually compatible dtypes.

    Parameters
    ----------
    source_col : ColumnBase
        The column that ``other`` will be combined with.
    other : scalar or ColumnBase
        The scalar or column to align with ``source_col``. Scalars are first
        passed through ``_normalize_scalars``.
    inplace : bool
        When True, ``source_col`` is returned unchanged and ``other`` is cast
        to ``source_col``'s dtype; a warning is emitted if
        ``cudf.utils.dtypes._can_cast`` rejects that cast. When False, both
        values are cast to a shared dtype.

    Returns
    -------
    tuple
        ``(column, other)`` after casting. If ``source_col`` has a
        categorical dtype, both inputs are returned untouched. If ``other``
        is None, ``source_col`` is returned together with the normalized
        value.

    Raises
    ------
    TypeError
        If ``cudf.utils.dtypes.is_mixed_with_object_dtype`` reports the pair
        as mixing object and non-object dtypes.
    """
    dtypes = cudf.utils.dtypes

    # Categorical columns are handed back as-is; no casting is attempted.
    if dtypes.is_categorical_dtype(source_col.dtype):
        return source_col, other

    other_is_scalar = dtypes.is_scalar(other)
    device_obj = (
        _normalize_scalars(source_col, other) if other_is_scalar else other
    )

    if other is None:
        return source_col, device_obj
    if dtypes.is_mixed_with_object_dtype(device_obj, source_col):
        raise TypeError(
            "cudf does not support mixed types, please type-cast "
            "the column of dataframe/series and other "
            "to same dtypes."
        )

    if inplace:
        # The column's dtype wins; `other` is forced into it even if lossy.
        if not dtypes._can_cast(device_obj.dtype, source_col.dtype):
            warnings.warn(
                f"Type-casting from {device_obj.dtype} "
                f"to {source_col.dtype}, there could be potential data loss"
            )
        return source_col, device_obj.astype(source_col.dtype)

    target_dtype = source_col.dtype
    # A numeric (non-decimal) column can keep its own dtype when the scalar
    # fits into it.
    scalar_fits_column = (
        other_is_scalar
        and dtypes._is_non_decimal_numeric_dtype(target_dtype)
        and dtypes._can_cast(other, target_dtype)
    )
    if scalar_fits_column:
        return (
            source_col.astype(target_dtype),
            cudf.Scalar(other, dtype=target_dtype),
        )

    # Otherwise promote both sides to a shared dtype.
    other_dtype = np.min_scalar_type(other) if other_is_scalar else other.dtype
    common_dtype = dtypes.find_common_type([source_col.dtype, other_dtype])
    if dtypes.is_scalar(device_obj):
        device_obj = cudf.Scalar(other, dtype=common_dtype)
    else:
        device_obj = device_obj.astype(common_dtype)
    return source_col.astype(common_dtype), device_obj
def _match_categorical_dtypes(lcol: ColumnBase, rcol: ColumnBase,
                              how: str) -> Tuple[ColumnBase, ColumnBase]:
    """Cast join keys to a common dtype when at least one is categorical.

    Parameters
    ----------
    lcol, rcol : ColumnBase
        Left and right join-key columns. At least one must have a
        categorical dtype.
    how : str
        The join type. For ``"left"``, ``"leftsemi"`` and ``"leftanti"``,
        a categorical left key keeps its dtype and the right key is cast
        to it.

    Returns
    -------
    tuple of ColumnBase
        ``(lcol, rcol)`` cast to compatible dtypes. If both are
        ``CategoricalColumn`` instances, the result comes from
        ``_match_categorical_dtypes_both``. Otherwise, both keys are cast
        to the dtype of the categorical side's categories (except in the
        left-join case above).

    Raises
    ------
    TypeError
        If neither column has a categorical dtype. Previously this case
        fell through to an ``UnboundLocalError``.
    """
    ltype, rtype = lcol.dtype, rcol.dtype
    if isinstance(lcol, cudf.core.column.CategoricalColumn) and isinstance(
            rcol, cudf.core.column.CategoricalColumn):
        # if both are categoricals, logic is complicated:
        return _match_categorical_dtypes_both(lcol, rcol, how)

    if isinstance(ltype, CategoricalDtype):
        if how in {"left", "leftsemi", "leftanti"}:
            # Left-style joins keep the left key's categorical dtype.
            return lcol, rcol.astype(ltype)
        common_type = ltype.categories.dtype
    elif isinstance(rtype, CategoricalDtype):
        common_type = rtype.categories.dtype
    else:
        # Precondition violated: without this, `common_type` would be
        # unbound below.
        raise TypeError(
            "_match_categorical_dtypes requires at least one categorical "
            f"join key, got {ltype} and {rtype}")
    return lcol.astype(common_type), rcol.astype(common_type)
def _match_join_keys(lcol: ColumnBase, rcol: ColumnBase,
                     how: str) -> Tuple[ColumnBase, ColumnBase]:
    """Cast ``lcol`` and ``rcol`` to dtypes usable as left/right join keys.

    Parameters
    ----------
    lcol, rcol : ColumnBase
        Left and right join-key columns.
    how : str
        The join type. For ``"left"``, the right key is cast to the left
        key's dtype when ``can_cast_safely`` allows it.

    Returns
    -------
    tuple of ColumnBase
        The (possibly cast) ``(lcol, rcol)`` pair. If the dtypes are already
        equal, the columns are returned unchanged.

    Raises
    ------
    TypeError
        If either key is a ``cudf.Decimal64Dtype`` column and the dtypes
        are not equal.
    """
    # NOTE(review): a second ``_match_join_keys`` definition appears later in
    # this source; if both are in one module, the later one shadows this one.
    common_type = None

    # cast the keys lcol and rcol to a common dtype
    ltype = lcol.dtype
    rtype = rcol.dtype

    # if either side is categorical, different logic
    if isinstance(ltype, CategoricalDtype) or isinstance(
            rtype, CategoricalDtype):
        return _match_categorical_dtypes(lcol, rcol, how)

    if pd.api.types.is_dtype_equal(ltype, rtype):
        return lcol, rcol

    if isinstance(ltype, cudf.Decimal64Dtype) or isinstance(
            rtype, cudf.Decimal64Dtype):
        raise TypeError(
            "Decimal columns can only be merged with decimal columns "
            "of the same precision and scale")

    if np.issubdtype(ltype, np.number) and np.issubdtype(rtype, np.number):
        # Same kind (e.g. both ints): take the wider one. Mixed kinds: let
        # NumPy pick a common type.
        common_type = (max(ltype, rtype) if ltype.kind == rtype.kind
                       else np.find_common_type([], (ltype, rtype)))
    elif np.issubdtype(ltype, np.datetime64) and np.issubdtype(
            rtype, np.datetime64):
        common_type = max(ltype, rtype)
    # NOTE(review): for any other dtype pair, common_type stays None and is
    # passed to astype below -- confirm this is intended.

    if how == "left":
        # Prefer preserving the left key's dtype for left joins.
        if rcol.fillna(0).can_cast_safely(ltype):
            return lcol, rcol.astype(ltype)
        else:
            warnings.warn(f"Can't safely cast column from {rtype} to {ltype}, "
                          f"upcasting to {common_type}.")

    return lcol.astype(common_type), rcol.astype(common_type)
def _safe_cast_to_int(col: ColumnBase, dtype: DtypeObj) -> ColumnBase:
    """
    Cast a NumericalColumn to the integer ``dtype``, refusing lossy casts.

    Parameters
    ----------
    col : ColumnBase
        The column to convert.
    dtype : DtypeObj
        Target dtype; must satisfy ``is_integer_dtype``.

    Returns
    -------
    ColumnBase
        ``col`` itself if it already has ``dtype``. Otherwise, the result of
        ``col.astype(dtype)``.

    Raises
    ------
    TypeError
        If the cast result does not compare equal to ``col`` element-wise.
    """
    assert is_integer_dtype(dtype)

    # Nothing to do when the dtype already matches.
    if col.dtype == dtype:
        return col

    converted = col.astype(dtype)
    # Round-trip equality check: any value altered by the cast means data loss.
    if not (converted == col).all():
        source_name = col.dtype.type.__name__
        target_name = np.dtype(dtype).type.__name__
        raise TypeError(
            f"Cannot safely cast non-equivalent {source_name} to {target_name}"
        )
    return converted
def _match_join_keys(lcol: ColumnBase, rcol: ColumnBase,
                     how: str) -> Tuple[ColumnBase, ColumnBase]:
    """Cast a pair of join-key columns to dtypes that can be joined.

    Parameters
    ----------
    lcol, rcol : ColumnBase
        Left and right join-key columns.
    how : str
        The join type.
        - Categorical left key: for ``"left"``, ``"leftsemi"`` and
          ``"leftanti"``, the right key is cast to the left key's
          categorical dtype.
        - Non-categorical keys: for ``"left"``, the right key is cast to
          the left key's dtype when ``can_cast_safely`` allows it.

    Returns
    -------
    tuple of ColumnBase
        The (possibly cast) ``(lcol, rcol)`` pair. If no casting is needed,
        both are returned as-is.

    Raises
    ------
    TypeError
        If either key has a decimal dtype and the two dtypes are not equal.
    """
    left_dtype = lcol.dtype
    right_dtype = rcol.dtype
    left_cat = isinstance(left_dtype, CategoricalDtype)
    right_cat = isinstance(right_dtype, CategoricalDtype)

    # --- categorical keys follow their own rules ---
    if left_cat and right_cat:
        return _match_categorical_dtypes_both(cast(CategoricalColumn, lcol),
                                              cast(CategoricalColumn, rcol),
                                              how)
    if left_cat:
        if how in {"left", "leftsemi", "leftanti"}:
            return lcol, rcol.astype(left_dtype)
        categories_dtype = left_dtype.categories.dtype
        return lcol.astype(categories_dtype), rcol.astype(categories_dtype)
    if right_cat:
        categories_dtype = right_dtype.categories.dtype
        return lcol.astype(categories_dtype), rcol.astype(categories_dtype)

    # --- non-categorical keys ---
    if is_dtype_equal(left_dtype, right_dtype):
        return lcol, rcol

    if is_decimal_dtype(left_dtype) or is_decimal_dtype(right_dtype):
        raise TypeError(
            "Decimal columns can only be merged with decimal columns "
            "of the same precision and scale")

    # Stays None unless both sides are numeric or both are datetime64.
    target = None
    if np.issubdtype(left_dtype, np.number) and np.issubdtype(
            right_dtype, np.number):
        if left_dtype.kind == right_dtype.kind:
            target = max(left_dtype, right_dtype)
        else:
            target = np.find_common_type([], (left_dtype, right_dtype))
    elif np.issubdtype(left_dtype, np.datetime64) and np.issubdtype(
            right_dtype, np.datetime64):
        target = max(left_dtype, right_dtype)

    if how == "left":
        # Left joins try to keep the left key's dtype untouched.
        if rcol.fillna(0).can_cast_safely(left_dtype):
            return lcol, rcol.astype(left_dtype)
        warnings.warn(
            f"Can't safely cast column from {right_dtype} to {left_dtype}, "
            f"upcasting to {target}.")

    return lcol.astype(target), rcol.astype(target)