def _return_type(self) -> DataType:
    """Return the Spark type used for the model's predictions.

    When the stored hint is ``"infer"`` or falsy, ``np.float64`` (continuous
    predictions) is used instead; the chosen type is converted with
    ``as_spark_type``.
    """
    # Only the default case -- continuous predictions -- is handled for now.
    # TODO: be smarter, e.g. a sklearn classifier should produce an integer or a
    # categorical, and pytorch/tensorflow/keras models could be inspected for their
    # output types. That inspection probably belongs in mlflow rather than here.
    requested = self._return_type_hint
    fall_back_to_float = requested == "infer" or not requested
    return as_spark_type(np.float64 if fall_back_to_float else requested)
def test_as_spark_type_extension_float_dtypes(self):
    """Nullable pandas float dtypes map to Spark FloatType / DoubleType."""
    from pandas import Float32Dtype, Float64Dtype

    cases = [
        (Float32Dtype(), FloatType()),
        (Float64Dtype(), DoubleType()),
    ]
    for pandas_dtype, expected_spark_type in cases:
        self.assertEqual(as_spark_type(pandas_dtype), expected_spark_type)
        # koalas_dtype keeps the extension dtype itself as the pandas-side dtype.
        self.assertEqual(koalas_dtype(pandas_dtype), (pandas_dtype, expected_spark_type))
def test_as_spark_type_extension_object_dtypes(self):
    """pandas BooleanDtype / StringDtype map to Spark BooleanType / StringType."""
    from pandas import BooleanDtype, StringDtype

    cases = [
        (BooleanDtype(), BooleanType()),
        (StringDtype(), StringType()),
    ]
    for pandas_dtype, expected_spark_type in cases:
        self.assertEqual(as_spark_type(pandas_dtype), expected_spark_type)
        # koalas_dtype keeps the extension dtype itself as the pandas-side dtype.
        self.assertEqual(koalas_dtype(pandas_dtype), (pandas_dtype, expected_spark_type))
def test_as_spark_type_extension_dtypes(self):
    """Nullable pandas integer dtypes map to Spark Byte/Short/Integer/Long types."""
    from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype

    cases = [
        (Int8Dtype(), ByteType()),
        (Int16Dtype(), ShortType()),
        (Int32Dtype(), IntegerType()),
        (Int64Dtype(), LongType()),
    ]
    for pandas_dtype, expected_spark_type in cases:
        self.assertEqual(as_spark_type(pandas_dtype), expected_spark_type)
        # koalas_dtype keeps the extension dtype itself as the pandas-side dtype.
        self.assertEqual(koalas_dtype(pandas_dtype), (pandas_dtype, expected_spark_type))
def rsub(self, left, right) -> Union["Series", "Index"]:
    """Reverse subtraction ``right - left`` where ``right`` is a datetime scalar.

    Both operands are cast to Spark ``long`` and the result is computed as
    ``-(left - right)``. A ``UserWarning`` is emitted because the result is an
    integer in seconds, whereas pandas would return ``timedelta64[ns]``.

    Raises:
        TypeError: if ``right`` is not a ``datetime.datetime``.
    """
    if not isinstance(right, datetime.datetime):
        raise TypeError("datetime subtraction can only be applied to datetime series.")

    # Timestamp subtraction casts its arguments to integer to mimic pandas's behaviors;
    # pandas itself returns 'timedelta64[ns]' from 'datetime64[ns]'s subtraction.
    warnings.warn(
        "Note that there is a behavior difference of timestamp subtraction. "
        "The timestamp subtraction returns an integer in seconds, "
        "whereas pandas returns 'timedelta64[ns]'.",
        UserWarning,
    )
    left_as_long = left.astype("long")
    right_as_long = F.lit(right).cast(as_spark_type("long"))
    return -(left_as_long - right_as_long)
def rsub(self, left: T_IndexOps, right: Any) -> IndexOpsLike:
    """Reverse subtraction ``right - left`` where ``right`` is a datetime scalar.

    The Spark column behind ``left`` is transformed into
    ``long(right) - long(left)``. A ``UserWarning`` is emitted because the result is
    an integer in seconds, whereas pandas would return ``timedelta64[ns]``.

    Raises:
        TypeError: if ``right`` is not a ``datetime.datetime``.
    """
    if not isinstance(right, datetime.datetime):
        raise TypeError(
            "datetime subtraction can only be applied to datetime series.")

    # Timestamp subtraction casts its arguments to integer to mimic pandas's behaviors;
    # pandas itself returns 'timedelta64[ns]' from 'datetime64[ns]'s subtraction.
    warnings.warn(
        "Note that there is a behavior difference of timestamp subtraction. "
        "The timestamp subtraction returns an integer in seconds, "
        "whereas pandas returns 'timedelta64[ns]'.",
        UserWarning,
    )

    def _seconds_from_left(scol):
        # Literal datetime first, then the column, matching ``right - left``.
        return F.lit(right).cast(as_spark_type("long")) - scol.astype("long")

    return cast(IndexOpsLike, left.spark.transform(_seconds_from_left))
def test_as_spark_type_koalas_dtype(self):
    """Check ``as_spark_type`` and ``koalas_dtype`` on NumPy, Python and pandas types.

    Each key of ``type_mapper`` is passed to both functions. The value is the
    expected ``koalas_dtype`` result (pandas-side dtype, Spark type), and its second
    item is also the expected ``as_spark_type`` result. Unsupported dtypes must
    raise ``TypeError``.
    """
    # The deprecated NumPy aliases np.int, np.float, np.str and np.bool (removed in
    # NumPy 1.24) were the builtins int, float, str and bool. np.string_ and
    # np.unicode_ (removed in NumPy 2.0) were np.bytes_ and np.str_. As dict keys
    # they only duplicated the entries below, so the canonical names are used and
    # the dict no longer breaks on newer NumPy.
    type_mapper = {
        # binary
        np.character: (np.character, BinaryType()),
        np.bytes_: (np.bytes_, BinaryType()),
        bytes: (np.bytes_, BinaryType()),
        # integer
        np.int8: (np.int8, ByteType()),
        np.byte: (np.int8, ByteType()),
        np.int16: (np.int16, ShortType()),
        np.int32: (np.int32, IntegerType()),
        np.int64: (np.int64, LongType()),
        int: (np.int64, LongType()),
        # floating
        np.float32: (np.float32, FloatType()),
        np.float64: (np.float64, DoubleType()),
        float: (np.float64, DoubleType()),
        # string
        np.str_: (np.str_, StringType()),
        str: (np.str_, StringType()),
        # bool
        bool: (bool, BooleanType()),
        # datetime
        np.datetime64: (np.datetime64, TimestampType()),
        datetime.datetime: (np.dtype("datetime64[ns]"), TimestampType()),
        # DateType
        datetime.date: (np.dtype("object"), DateType()),
        # DecimalType
        decimal.Decimal: (np.dtype("object"), DecimalType(38, 18)),
        # ArrayType
        np.ndarray: (np.dtype("object"), ArrayType(StringType())),
        List[bytes]: (np.dtype("object"), ArrayType(BinaryType())),
        List[np.character]: (np.dtype("object"), ArrayType(BinaryType())),
        List[np.bytes_]: (np.dtype("object"), ArrayType(BinaryType())),
        List[bool]: (np.dtype("object"), ArrayType(BooleanType())),
        List[datetime.date]: (np.dtype("object"), ArrayType(DateType())),
        List[np.int8]: (np.dtype("object"), ArrayType(ByteType())),
        List[np.byte]: (np.dtype("object"), ArrayType(ByteType())),
        List[decimal.Decimal]: (np.dtype("object"), ArrayType(DecimalType(38, 18))),
        List[float]: (np.dtype("object"), ArrayType(DoubleType())),
        List[np.float64]: (np.dtype("object"), ArrayType(DoubleType())),
        List[np.float32]: (np.dtype("object"), ArrayType(FloatType())),
        List[np.int32]: (np.dtype("object"), ArrayType(IntegerType())),
        List[int]: (np.dtype("object"), ArrayType(LongType())),
        List[np.int64]: (np.dtype("object"), ArrayType(LongType())),
        List[np.int16]: (np.dtype("object"), ArrayType(ShortType())),
        List[str]: (np.dtype("object"), ArrayType(StringType())),
        List[np.str_]: (np.dtype("object"), ArrayType(StringType())),
        List[datetime.datetime]: (np.dtype("object"), ArrayType(TimestampType())),
        List[np.datetime64]: (np.dtype("object"), ArrayType(TimestampType())),
        # CategoricalDtype
        CategoricalDtype(categories=["a", "b", "c"]): (
            CategoricalDtype(categories=["a", "b", "c"]),
            LongType(),
        ),
    }

    for numpy_or_python_type, (dtype, spark_type) in type_mapper.items():
        self.assertEqual(as_spark_type(numpy_or_python_type), spark_type)
        self.assertEqual(koalas_dtype(numpy_or_python_type), (dtype, spark_type))

    with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
        as_spark_type(np.dtype("uint64"))

    with self.assertRaisesRegex(TypeError, "Type object was not understood."):
        as_spark_type(np.dtype("object"))

    with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
        koalas_dtype(np.dtype("uint64"))

    with self.assertRaisesRegex(TypeError, "Type object was not understood."):
        koalas_dtype(np.dtype("object"))
def test_as_spark_type_pandas_on_spark_dtype(self):
    """Check ``as_spark_type`` and ``pandas_on_spark_type`` on NumPy, Python and pandas types.

    Each key of ``type_mapper`` is checked directly. Except for ``CategoricalDtype``,
    it is also checked wrapped as ``List[...]`` and, when NumPy 1.21+ and Python 3.8+
    are available, as ``numpy.typing.NDArray[...]``. Both wrappers must map to an
    ``ArrayType`` of the element's Spark type with an object pandas dtype.
    Unsupported dtypes must raise ``TypeError``.
    """
    # The deprecated NumPy aliases np.int, np.float, np.str and np.bool (removed in
    # NumPy 1.24) were the builtins int, float, str and bool. np.string_ and
    # np.unicode_ (removed in NumPy 2.0) were np.bytes_ and np.str_. As dict keys
    # they only duplicated the entries below, so the canonical names are used and
    # the dict no longer breaks on newer NumPy.
    type_mapper = {
        # binary
        np.character: (np.character, BinaryType()),
        np.bytes_: (np.bytes_, BinaryType()),
        bytes: (np.bytes_, BinaryType()),
        # integer
        np.int8: (np.int8, ByteType()),
        np.byte: (np.int8, ByteType()),
        np.int16: (np.int16, ShortType()),
        np.int32: (np.int32, IntegerType()),
        np.int64: (np.int64, LongType()),
        int: (np.int64, LongType()),
        # floating
        np.float32: (np.float32, FloatType()),
        np.float64: (np.float64, DoubleType()),
        float: (np.float64, DoubleType()),
        # string
        np.str_: (np.str_, StringType()),
        str: (np.str_, StringType()),
        # bool
        bool: (bool, BooleanType()),
        # datetime
        np.datetime64: (np.datetime64, TimestampType()),
        datetime.datetime: (np.dtype("datetime64[ns]"), TimestampType()),
        # DateType
        datetime.date: (np.dtype("object"), DateType()),
        # DecimalType
        decimal.Decimal: (np.dtype("object"), DecimalType(38, 18)),
        # ArrayType
        np.ndarray: (np.dtype("object"), ArrayType(StringType())),
        # CategoricalDtype
        CategoricalDtype(categories=["a", "b", "c"]): (
            CategoricalDtype(categories=["a", "b", "c"]),
            LongType(),
        ),
    }

    # For NumPy typing, NumPy version should be 1.21+ and Python version should be 3.8+.
    # This does not depend on the loop variable, so resolve it once up front.
    ntp = None
    if sys.version_info >= (3, 8) and LooseVersion(np.__version__) >= LooseVersion("1.21"):
        import numpy.typing as ntp

    for numpy_or_python_type, (dtype, spark_type) in type_mapper.items():
        self.assertEqual(as_spark_type(numpy_or_python_type), spark_type)
        self.assertEqual(pandas_on_spark_type(numpy_or_python_type), (dtype, spark_type))

        if isinstance(numpy_or_python_type, CategoricalDtype):
            # Nested CategoricalDtype is not yet supported.
            continue

        self.assertEqual(as_spark_type(List[numpy_or_python_type]), ArrayType(spark_type))
        self.assertEqual(
            pandas_on_spark_type(List[numpy_or_python_type]),
            (np.dtype("object"), ArrayType(spark_type)),
        )

        if ntp is not None:
            self.assertEqual(
                as_spark_type(ntp.NDArray[numpy_or_python_type]), ArrayType(spark_type)
            )
            self.assertEqual(
                pandas_on_spark_type(ntp.NDArray[numpy_or_python_type]),
                (np.dtype("object"), ArrayType(spark_type)),
            )

    with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
        as_spark_type(np.dtype("uint64"))

    with self.assertRaisesRegex(TypeError, "Type object was not understood."):
        as_spark_type(np.dtype("object"))

    with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
        pandas_on_spark_type(np.dtype("uint64"))

    with self.assertRaisesRegex(TypeError, "Type object was not understood."):
        pandas_on_spark_type(np.dtype("object"))