def infer_return_type(f) -> typing.Union[SeriesType, DataFrameType, ScalarType, UnknownType]:
    """
    >>> def func() -> int:
    ...    pass
    >>> infer_return_type(func).tpe
    LongType

    >>> def func() -> ks.Series[int]:
    ...    pass
    >>> infer_return_type(func).tpe
    LongType

    >>> def func() -> ks.DataFrame[np.float, str]:
    ...    pass
    >>> infer_return_type(func).tpe
    StructType(List(StructField(c0,DoubleType,true),StructField(c1,StringType,true)))

    >>> def func() -> ks.DataFrame[np.float]:
    ...    pass
    >>> infer_return_type(func).tpe
    StructType(List(StructField(c0,DoubleType,true)))

    >>> def func() -> 'int':
    ...    pass
    >>> infer_return_type(func).tpe
    LongType

    >>> def func() -> 'ks.Series[int]':
    ...    pass
    >>> infer_return_type(func).tpe
    LongType

    >>> def func() -> 'ks.DataFrame[np.float, str]':
    ...    pass
    >>> infer_return_type(func).tpe
    StructType(List(StructField(c0,DoubleType,true),StructField(c1,StringType,true)))

    >>> def func() -> 'ks.DataFrame[np.float]':
    ...    pass
    >>> infer_return_type(func).tpe
    StructType(List(StructField(c0,DoubleType,true)))

    >>> def func() -> ks.DataFrame['a': np.float, 'b': int]:
    ...    pass
    >>> infer_return_type(func).tpe
    StructType(List(StructField(a,DoubleType,true),StructField(b,LongType,true)))

    >>> def func() -> "ks.DataFrame['a': np.float, 'b': int]":
    ...    pass
    >>> infer_return_type(func).tpe
    StructType(List(StructField(a,DoubleType,true),StructField(b,LongType,true)))

    >>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
    >>> def func() -> ks.DataFrame[pdf.dtypes]:
    ...    pass
    >>> infer_return_type(func).tpe
    StructType(List(StructField(c0,LongType,true),StructField(c1,LongType,true)))

    >>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
    >>> def func() -> ks.DataFrame[zip(pdf.columns, pdf.dtypes)]:
    ...    pass
    >>> infer_return_type(func).tpe
    StructType(List(StructField(a,LongType,true),StructField(b,LongType,true)))
    """
    # Import here rather than at module top so 'SeriesType' resolves to the
    # canonical class; Series.__class_getitem__ imports it the same way, and a
    # module-local copy would break the issubclass checks below.
    from databricks.koalas.typedef import NameTypeHolder, SeriesType

    hint = getfullargspec(f).annotations.get("return", None)
    if isinstance(hint, str):
        # A string annotation is a forward reference; resolve it to a real hint.
        hint = resolve_string_type_hint(hint)

    if hasattr(hint, "__origin__") and (
        issubclass(hint.__origin__, SeriesType) or hint.__origin__ == ks.Series
    ):
        # TODO: remove "hint.__origin__ == ks.Series" when we drop Python 3.5 and 3.6.
        return SeriesType(as_spark_type(hint.__args__[0]))

    if hasattr(hint, "__origin__") and hint.__origin__ == ks.DataFrame:
        # On Python < 3.7 the DataFrame hint wraps the parameter; unwrap it to
        # the inner Tuple hint.
        hint = hint.__args__[0]

    # DataFrame type hints materialize as a Tuple. Python 3.6 exposes the name
    # via `__name__`, Python 3.7/3.8 via `_name` -- check both.
    if getattr(hint, "_name", getattr(hint, "__name__", None)) == "Tuple":
        if hasattr(hint, "__tuple_params__"):
            # Python 3.5.0 to 3.5.2 stored the parameters in '__tuple_params__'.
            # See https://github.com/python/cpython/blob/v3.5.2/Lib/typing.py
            params = hint.__tuple_params__
        else:
            params = hint.__args__
        if len(params) > 0 and all(
            isclass(p) and issubclass(p, NameTypeHolder) for p in params
        ):
            # Every parameter carries both a column name and its type.
            names = [p.name for p in params if issubclass(p, NameTypeHolder)]
            types = [p.tpe for p in params if issubclass(p, NameTypeHolder)]
            return DataFrameType([as_spark_type(t) for t in types], names)
        # Plain types only: column names are generated by DataFrameType.
        return DataFrameType([as_spark_type(t) for t in params])

    spark_type = as_spark_type(hint)
    return UnknownType(hint) if spark_type is None else ScalarType(spark_type)
def infer_return_type(f) -> Union[SeriesType, DataFrameType, ScalarType, UnknownType]:
    """
    Infer the return type from the return type annotation of the given function.

    The returned type class indicates both dtypes (a pandas only dtype object
    or a numpy dtype object) and its corresponding Spark DataType.

    >>> def func() -> int:
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtype
    dtype('int64')
    >>> inferred.spark_type
    LongType

    >>> def func() -> ks.Series[int]:
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtype
    dtype('int64')
    >>> inferred.spark_type
    LongType

    >>> def func() -> ks.DataFrame[np.float, str]:
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64'), dtype('<U')]
    >>> inferred.spark_type
    StructType(List(StructField(c0,DoubleType,true),StructField(c1,StringType,true)))

    >>> def func() -> ks.DataFrame[np.float]:
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64')]
    >>> inferred.spark_type
    StructType(List(StructField(c0,DoubleType,true)))

    >>> def func() -> 'int':
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtype
    dtype('int64')
    >>> inferred.spark_type
    LongType

    >>> def func() -> 'ks.Series[int]':
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtype
    dtype('int64')
    >>> inferred.spark_type
    LongType

    >>> def func() -> 'ks.DataFrame[np.float, str]':
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64'), dtype('<U')]
    >>> inferred.spark_type
    StructType(List(StructField(c0,DoubleType,true),StructField(c1,StringType,true)))

    >>> def func() -> 'ks.DataFrame[np.float]':
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64')]
    >>> inferred.spark_type
    StructType(List(StructField(c0,DoubleType,true)))

    >>> def func() -> ks.DataFrame['a': np.float, 'b': int]:
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64'), dtype('int64')]
    >>> inferred.spark_type
    StructType(List(StructField(a,DoubleType,true),StructField(b,LongType,true)))

    >>> def func() -> "ks.DataFrame['a': np.float, 'b': int]":
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('float64'), dtype('int64')]
    >>> inferred.spark_type
    StructType(List(StructField(a,DoubleType,true),StructField(b,LongType,true)))

    >>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
    >>> def func() -> ks.DataFrame[pdf.dtypes]:
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('int64'), dtype('int64')]
    >>> inferred.spark_type
    StructType(List(StructField(c0,LongType,true),StructField(c1,LongType,true)))

    >>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
    >>> def func() -> ks.DataFrame[zip(pdf.columns, pdf.dtypes)]:
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('int64'), dtype('int64')]
    >>> inferred.spark_type
    StructType(List(StructField(a,LongType,true),StructField(b,LongType,true)))

    >>> pdf = pd.DataFrame({("x", "a"): [1, 2, 3], ("y", "b"): [3, 4, 5]})
    >>> def func() -> ks.DataFrame[zip(pdf.columns, pdf.dtypes)]:
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('int64'), dtype('int64')]
    >>> inferred.spark_type
    StructType(List(StructField((x, a),LongType,true),StructField((y, b),LongType,true)))

    >>> pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical([3, 4, 5])})
    >>> def func() -> ks.DataFrame[pdf.dtypes]:
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
    >>> inferred.spark_type
    StructType(List(StructField(c0,LongType,true),StructField(c1,LongType,true)))

    >>> def func() -> ks.DataFrame[zip(pdf.columns, pdf.dtypes)]:
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtypes
    [dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
    >>> inferred.spark_type
    StructType(List(StructField(a,LongType,true),StructField(b,LongType,true)))

    >>> def func() -> ks.Series[pdf.b.dtype]:
    ...    pass
    >>> inferred = infer_return_type(func)
    >>> inferred.dtype
    CategoricalDtype(categories=[3, 4, 5], ordered=False)
    >>> inferred.spark_type
    LongType
    """
    # Import here rather than at module top so 'SeriesType' resolves to the
    # canonical class; Series.__class_getitem__ imports it the same way, and a
    # module-local copy would break the issubclass checks below.
    from databricks.koalas.typedef import NameTypeHolder, SeriesType

    hint = getfullargspec(f).annotations.get("return", None)
    if isinstance(hint, str):
        # A string annotation is a forward reference; resolve it to a real hint.
        hint = resolve_string_type_hint(hint)

    if hasattr(hint, "__origin__") and (
        hint.__origin__ == ks.DataFrame or hint.__origin__ == ks.Series
    ):
        # On Python < 3.7 the hint wraps the parameter; unwrap it to the inner
        # Tuple / SeriesType hint.
        hint = hint.__args__[0]

    if hasattr(hint, "__origin__") and issubclass(hint.__origin__, SeriesType):
        hint = hint.__args__[0]
        if issubclass(hint, NameTypeHolder):
            # A NameTypeHolder subclass carries the actual type under `.tpe`.
            hint = hint.tpe
        dtype, spark_type = koalas_dtype(hint)
        return SeriesType(dtype, spark_type)

    # DataFrame type hints materialize as a Tuple. Python 3.6 exposes the name
    # via `__name__`, Python 3.7/3.8 via `_name` -- check both.
    if getattr(hint, "_name", getattr(hint, "__name__", None)) == "Tuple":
        if hasattr(hint, "__tuple_params__"):
            # Python 3.5.0 to 3.5.2 stored the parameters in '__tuple_params__'.
            # See https://github.com/python/cpython/blob/v3.5.2/Lib/typing.py
            params = hint.__tuple_params__
        else:
            params = hint.__args__

        def as_dtype_pair(param):
            # NameTypeHolder subclasses carry the actual type under `.tpe`;
            # plain parameters are converted directly.
            if isclass(param) and issubclass(param, NameTypeHolder):
                return koalas_dtype(param.tpe)
            return koalas_dtype(param)

        dtypes, spark_types = zip(*map(as_dtype_pair, params))
        # Column names come from NameTypeHolder entries; None means "generated".
        names = [
            p.name if isclass(p) and issubclass(p, NameTypeHolder) else None
            for p in params
        ]
        return DataFrameType(list(dtypes), list(spark_types), names)

    inferred = koalas_dtype(hint)
    if inferred is None:
        return UnknownType(hint)
    return ScalarType(*inferred)