def test_bad_tag():
    """Check that guess_python_type raises ValueError for unusable STRUCT metadata."""
    # Will not be able to load this tag.
    with pytest.raises(ValueError):
        TypeEngine.guess_python_type(
            LiteralType(simple=SimpleType.STRUCT, metadata={"pb_type": "bad.tag"})
        )

    # Empty metadata doesn't match the pb field key.
    with pytest.raises(ValueError):
        TypeEngine.guess_python_type(
            LiteralType(simple=SimpleType.STRUCT, metadata={})
        )
def test_create_native_named_tuple():
    """Exercise create_native_named_tuple with no, one, and several promises."""
    ctx = FlyteContextManager.current_context()

    def int_promise(name, value):
        # Wrap an integer in a literal and expose it as a named promise.
        literal = TypeEngine.to_literal(
            ctx, value, int, LiteralType(simple=SimpleType.INTEGER)
        )
        return Promise(var=name, val=literal)

    # No promises at all -> nothing to return.
    nothing = create_native_named_tuple(
        ctx, promises=None, entity_interface=Interface()
    )
    assert nothing is None

    p1 = int_promise("x", 1)
    p2 = int_promise("y", 2)

    # A single promise comes back as the bare native value.
    single = create_native_named_tuple(
        ctx, promises=p1, entity_interface=Interface(outputs={"x": int})
    )
    assert single
    assert single == 1

    # An empty list of promises also yields None.
    empty = create_native_named_tuple(
        ctx, promises=[], entity_interface=Interface()
    )
    assert empty is None

    # Several promises come back as a tuple.
    two_outputs = {"x": int, "y": int}
    pair = create_native_named_tuple(
        ctx, promises=[p1, p2], entity_interface=Interface(outputs=two_outputs)
    )
    assert pair
    assert pair == (1, 2)

    # With an explicit tuple name the resulting class carries that name.
    named_iface = Interface(outputs={"x": int, "y": int}, output_tuple_name="Tup")
    named = create_native_named_tuple(
        ctx, promises=[p1, p2], entity_interface=named_iface
    )
    assert named
    assert named == (1, 2)
    assert named.__class__.__name__ == "Tup"

    # A promise with no matching output in the interface is a KeyError.
    with pytest.raises(KeyError):
        create_native_named_tuple(
            ctx,
            promises=[p1, p2],
            entity_interface=Interface(outputs={"x": int}, output_tuple_name="Tup"),
        )
def test_engine_file_output():
    """Round-trip a local file path through TypeEngine as an os.PathLike blob."""
    single_blob = _core_types.BlobType(
        format="",
        dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
    )
    provider = FileAccessProvider(local_sandbox_dir="/tmp/flytetesting")
    current = context_manager.FlyteContext.current_context()

    with current.new_file_access_context(file_access_provider=provider) as ctx:
        # Write some text to a file that is not inside the sandbox directory.
        source_path = "/tmp/sample.txt"
        with open(source_path, "w") as out:
            out.write("Hello World\n")

        literal = TypeEngine.to_literal(
            ctx, source_path, os.PathLike, LiteralType(blob=single_blob)
        )

        # Local storage doubles as the 'remote', so the blob uri is directly readable.
        with open(literal.scalar.blob.uri, "r") as uploaded:
            assert uploaded.readline() == "Hello World\n"

        # Converting back to a native value must give a readable local path too.
        local_copy = TypeEngine.to_python_value(ctx, literal, os.PathLike)
        with open(local_copy, "r") as downloaded:
            assert downloaded.readline() == "Hello World\n"
def test_get_literal_type():
    """FlytePickleTransformer maps FlytePickle to a single blob in the pickle format."""
    actual = FlytePickleTransformer().get_literal_type(FlytePickle)
    expected_blob = BlobType(
        format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT,
        dimensionality=BlobType.BlobDimensionality.SINGLE,
    )
    assert actual == LiteralType(blob=expected_blob)
def test_parameter_ranges_transformer():
    """Round-trip a ParameterRangeOneOf through ParameterRangesTransformer."""
    transformer = ParameterRangesTransformer()
    struct_type = LiteralType(simple=SimpleType.STRUCT)
    assert transformer.get_literal_type(ParameterRangeOneOf) == struct_type

    original = ParameterRangeOneOf(param=IntegerParameterRange(10, 0, 1))
    ctx = FlyteContext.current_context()

    # Python value -> literal: expected to land in the generic scalar slot.
    literal = transformer.to_literal(
        ctx,
        python_val=original,
        python_type=ParameterRangeOneOf,
        expected=None,
    )
    assert literal is not None
    assert literal.scalar.generic is not None

    # Literal -> python value: must compare equal to what we started with.
    restored = transformer.to_python_value(ctx, literal, ParameterRangeOneOf)
    assert restored is not None
    assert restored == original
def get_literal_type(
    self, t: typing.Union[Type[StructuredDataset], typing.Any]
) -> LiteralType:
    """
    Provide a concrete implementation so that writers of custom dataframe handlers
    do not have to, since there's nothing that special about the literal type.
    Any dataframe type will always be associated with the structured dataset type.
    The other aspects of it - columns, external schema type, etc. - can be read
    from associated metadata.

    :param t: The python dataframe type, which is mostly ignored.
    :return: A LiteralType wrapping the dataset type computed by ``self._get_dataset_type(t)``.
    """
    dataset_type = self._get_dataset_type(t)
    return LiteralType(structured_dataset_type=dataset_type)
def get_literal_type(self, t: Type[GreatExpectationsType]) -> LiteralType:
    """
    Resolve the LiteralType from the datatype configured on ``t``.

    The datatype is the first element of ``GreatExpectationsTypeTransformer.get_config(t)``.
    ``str`` subclasses map to a STRING literal type; ``FlyteFile`` and ``FlyteSchema``
    subclasses are delegated to their own transformers.

    :param t: The GreatExpectationsType to inspect.
    :raises TypeError: If the configured datatype is none of the supported kinds.
    """
    datatype = GreatExpectationsTypeTransformer.get_config(t)[0]

    if issubclass(datatype, str):
        return LiteralType(simple=_type_models.SimpleType.STRING, metadata={})
    if issubclass(datatype, FlyteFile):
        return FlyteFilePathTransformer().get_literal_type(datatype)
    if issubclass(datatype, FlyteSchema):
        return FlyteSchemaTransformer().get_literal_type(datatype)

    raise TypeError(f"{datatype} is not a supported type")
def test_structured_dataset():
    """A structured dataset literal type is castable to itself."""
    columns = [
        StructuredDatasetType.DatasetColumn("a", str_type),
        StructuredDatasetType.DatasetColumn("b", int_type),
    ]
    dataset_type = StructuredDatasetType(
        columns=columns,
        format="abc",
        external_schema_type="zzz",
        external_schema_bytes=b"zzz",
    )
    x = LiteralType(structured_dataset_type=dataset_type)
    assert _are_types_castable(x, x)
def test_hpoconfig_transformer():
    """Round-trip a HyperparameterTuningJobConfig through HPOTuningJobConfigTransformer."""
    transformer = HPOTuningJobConfigTransformer()
    struct_type = LiteralType(simple=SimpleType.STRUCT)
    assert transformer.get_literal_type(HyperparameterTuningJobConfig) == struct_type

    objective = HyperparameterTuningObjective(
        objective_type=HyperparameterTuningObjectiveType.MINIMIZE,
        metric_name="x",
    )
    original = HyperparameterTuningJobConfig(
        tuning_strategy=1,
        tuning_objective=objective,
        training_job_early_stopping_type=TrainingJobEarlyStoppingType.OFF,
    )
    ctx = FlyteContext.current_context()

    # Python value -> literal: expected to land in the generic scalar slot.
    literal = transformer.to_literal(
        ctx,
        python_val=original,
        python_type=HyperparameterTuningJobConfig,
        expected=None,
    )
    assert literal is not None
    assert literal.scalar.generic is not None

    # Literal -> python value: must compare equal to what we started with.
    restored = transformer.to_python_value(ctx, literal, HyperparameterTuningJobConfig)
    assert restored is not None
    assert restored == original
def get_literal_type(self, t: Type[_params.ParameterRangeOneOf]) -> LiteralType:
    """
    Return the STRUCT literal type used for ParameterRangeOneOf values.

    :param t: The python type being described; not consulted here.
    :return: A LiteralType with ``simple=SimpleType.STRUCT`` and no metadata.
    """
    struct_type = LiteralType(simple=SimpleType.STRUCT, metadata=None)
    return struct_type
def get_literal_type(
    self, t: Type[_hpo_job_model.HyperparameterTuningJobConfig]
) -> LiteralType:
    """
    Return the STRUCT literal type used for HyperparameterTuningJobConfig values.

    :param t: The python type being described; not consulted here.
    :return: A LiteralType with ``simple=SimpleType.STRUCT`` and no metadata.
    """
    struct_type = LiteralType(simple=SimpleType.STRUCT, metadata=None)
    return struct_type
def get_literal_type(self, t: Type[T]) -> LiteralType:
    """
    Build the enum LiteralType for the enum class ``t``.

    Every member value is validated, not just the first one, so an enum that
    mixes string and non-string values is rejected instead of slipping through.

    :param t: The enum class whose member values become the allowed enum values.
    :return: A LiteralType whose ``enum_type`` lists the member values of ``t``
        in iteration order.
    :raises AssertionError: If ``t`` has no members, or if any member value is
        not a string.
    """
    values = [member.value for member in t]

    # An enum without members used to fail with a bare IndexError on values[0];
    # report the actual problem instead.
    if not values:
        raise AssertionError(f"Enum type {t} has no members; at least one string value is required")

    if not all(isinstance(value, str) for value in values):
        raise AssertionError("Only EnumTypes with value of string are supported")

    return LiteralType(enum_type=_core_types.EnumType(values=values))
def get_literal_type(self, t: Type[T]) -> LiteralType:
    """
    Return a STRUCT literal type tagged with the identifier of ``t``.

    The tag produced by ``self.tag(t)`` is stored in the metadata under
    ``ProtobufTransformer.PB_FIELD_KEY``.

    :param t: The python type to describe.
    """
    pb_metadata = {ProtobufTransformer.PB_FIELD_KEY: self.tag(t)}
    return LiteralType(simple=SimpleType.STRUCT, metadata=pb_metadata)
def get_literal_type(self, t: Type[pandas.DataFrame]) -> LiteralType:
    """
    Return the schema literal type used for pandas DataFrames.

    :param t: The dataframe type; not consulted here.
    :return: A LiteralType wrapping ``self._get_schema_type()``.
    """
    schema_type = self._get_schema_type()
    return LiteralType(schema=schema_type)
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
    """
    Convert a Literal (say the output of an earlier task) into the Python value
    requested at the start of a task execution.

    The only tricky part is the column subsetting behavior. For example, given

        def t1() -> Annotated[StructuredDataset, kwtypes(col_a=int, col_b=float)]: ...
        def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...

    where t2(in_a=t1()), when t2 does in_a.open(pd.DataFrame).all() it should
    get a DataFrame with only one column.

    Which columns end up on the StructuredDatasetType handed to the decoder:

    * Running task declares columns (whatever the incoming Literal has):
      the columns are those of the running task's type annotation. Decoders
      **should** then subset the incoming data to the columns requested.
    * Running task has [] columns or None, incoming Literal has columns:
      the columns of the incoming Literal are used. This is the scenario where
      the Literal carries more information than the running task's signature.
    * Neither side has columns: an empty list of columns is used.

    :param ctx: Context forwarded to ``self.open_as`` when a dataframe is requested.
    :param lv: The incoming literal; either a legacy schema scalar or a
        structured dataset scalar.
    :param expected_python_type: The (possibly annotated) type the task asked for.
    :return: A StructuredDataset (sub)class instance when one is requested,
        otherwise whatever ``self.open_as`` returns for the dataframe type.
    """
    # Detect annotations and extract all the information the user might have supplied.
    expected_python_type, column_dict, _storage_fmt, _pa_schema = extract_cols_and_format(
        expected_python_type
    )

    # The incoming literal might be an old FlyteSchema, which is still supported
    # for the time being. The column logic mirrors the structured dataset path below.
    legacy_schema = lv.scalar.schema
    if legacy_schema is not None:
        schema_columns = legacy_schema.type.columns

        if column_dict is not None and len(column_dict) > 0:
            # The running task declares columns: those win.
            final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
        elif schema_columns is not None and schema_columns != []:
            # No columns on the task side: carry over the schema's columns.
            final_dataset_columns = [
                StructuredDatasetType.DatasetColumn(
                    name=col.name,
                    literal_type=LiteralType(
                        simple=convert_schema_type_to_structured_dataset_type(col.type),
                    ),
                )
                for col in schema_columns
            ]
        else:
            final_dataset_columns = []

        # Dataframe will always be serialized to parquet file by FlyteSchema transformer
        new_sdt = StructuredDatasetType(columns=final_dataset_columns, format=PARQUET)
        metad = literals.StructuredDatasetMetadata(structured_dataset_type=new_sdt)
        sd_literal = literals.StructuredDataset(uri=legacy_schema.uri, metadata=metad)

        if not issubclass(expected_python_type, StructuredDataset):
            return self.open_as(ctx, sd_literal, expected_python_type, metad)

        sd = StructuredDataset(dataframe=None, metadata=metad)
        sd._literal_sd = sd_literal
        return sd

    # Start handling for StructuredDataset scalars, first look at the columns.
    incoming_sd = lv.scalar.structured_dataset
    incoming_type = incoming_sd.metadata.structured_dataset_type
    incoming_columns = incoming_type.columns

    if column_dict is not None and len(column_dict) > 0:
        # The running task's input has columns defined.
        final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
    elif incoming_columns is not None and incoming_columns != []:
        # The task declares no columns but the incoming literal has some: copy them over.
        final_dataset_columns = incoming_columns.copy()
    else:
        # Neither side has columns, so we just have an empty list.
        final_dataset_columns = []

    new_sdt = StructuredDatasetType(
        columns=final_dataset_columns,
        format=incoming_type.format,
        external_schema_type=incoming_type.external_schema_type,
        external_schema_bytes=incoming_type.external_schema_bytes,
    )
    metad = StructuredDatasetMetadata(structured_dataset_type=new_sdt)

    # If the requested type was not a StructuredDataset, then it was a plain dataframe type,
    # which means the opening/downloading has to happen right now. No iteration option here.
    if not issubclass(expected_python_type, StructuredDataset):
        return self.open_as(
            ctx,
            incoming_sd,
            df_type=expected_python_type,
            updated_metadata=metad,
        )

    # A StructuredDataset type, for example
    #   t1(input_a: StructuredDataset)
    # or
    #   t1(input_a: Annotated[StructuredDataset, my_cols])
    # Note that the requested type itself is instantiated here.
    sd = expected_python_type(dataframe=None, metadata=metad)
    sd._literal_sd = incoming_sd
    sd.file_format = metad.structured_dataset_type.format
    return sd
def get_literal_type(self, t: Type[DatasetProfileView]) -> LiteralType:
    """
    Return the blob literal type used for DatasetProfileView values.

    :param t: The python type being described; not consulted here.
    :return: A LiteralType whose blob is ``self._TYPE_INFO``.
    """
    blob_info = self._TYPE_INFO
    return LiteralType(blob=blob_info)
def get_literal_type(self, t: Type[T]) -> LiteralType:
    """
    Return the blob literal type used for pickled values.

    :param t: The python type being described; not consulted here.
    :return: A LiteralType for a SINGLE-dimensionality blob whose format is
        ``self.PYTHON_PICKLE_FORMAT``.
    """
    blob_cls = _core_types.BlobType
    pickle_blob = blob_cls(
        format=self.PYTHON_PICKLE_FORMAT,
        dimensionality=blob_cls.BlobDimensionality.SINGLE,
    )
    return LiteralType(blob=pickle_blob)
def get_literal_type(self, t: Type[pandera.typing.DataFrame]) -> LiteralType:
    """
    Return the schema literal type for a pandera DataFrame type.

    :param t: The pandera dataframe type, passed on to ``self._get_schema_type``.
    :return: A LiteralType wrapping the schema type computed for ``t``.
    """
    schema_type = self._get_schema_type(t)
    return LiteralType(schema=schema_type)
def get_literal_type(
    self, t: typing.Union[typing.Type[FlyteFile], os.PathLike]
) -> LiteralType:
    """
    Return the blob literal type for a FlyteFile (or path-like) type.

    The blob format is whatever ``FlyteFilePathTransformer.get_format(t)`` reports.

    :param t: The file type to describe.
    :return: A LiteralType whose blob comes from ``self._blob_type``.
    """
    file_format = FlyteFilePathTransformer.get_format(t)
    blob = self._blob_type(format=file_format)
    return LiteralType(blob=blob)
def get_literal_type(self, t: Type[DoltTable]) -> LiteralType:
    """
    Return the STRUCT literal type used for DoltTable values.

    :param t: The python type being described; not consulted here.
    :return: A LiteralType with ``simple=SimpleType.STRUCT`` and empty metadata.
    """
    struct_kind = _type_models.SimpleType.STRUCT
    return LiteralType(simple=struct_kind, metadata={})
def test_collection_union():
    """Castability between collections of unions and unions of collections."""

    def list_of(element_type):
        # Fresh collection literal type on every call, as in a hand-written assertion.
        return LiteralType(collection_type=element_type)

    def one_of(*variants):
        # Fresh union literal type over the given variants.
        return LiteralType(union_type=UnionType(list(variants)))

    # a list of str is a list of (str | int)
    assert _are_types_castable(list_of(str_type), list_of(str_or_int))
    # a list of int is a list of (str | int)
    assert _are_types_castable(list_of(int_type), list_of(str_or_int))

    # a list of int or a list of str is a list of (str | int) ...
    assert _are_types_castable(
        one_of(list_of(int_type), list_of(str_type)),
        list_of(str_or_int),
    )
    # ... and also a list of (str | int | bool)
    assert _are_types_castable(
        one_of(list_of(int_type), list_of(str_type)),
        list_of(str_or_int_or_bool),
    )

    # a list of int or a list of bool is not a list of (str | int)
    assert not _are_types_castable(
        one_of(list_of(int_type), list_of(bool_type)),
        list_of(str_or_int),
    )

    # not the other way around
    assert not _are_types_castable(list_of(str_or_int), list_of(str_type))
    assert not _are_types_castable(list_of(str_or_int), list_of(int_type))
def test_dict_transformer():
    """Check DictTransformer type inference and literal <-> python conversion."""
    d = DictTransformer()

    def assert_struct(lit: LiteralType):
        assert lit is not None
        assert lit.simple == SimpleType.STRUCT

    def recursive_assert(lit: LiteralType, expected: LiteralType, expected_depth: int = 1, curr_depth: int = 0):
        # Walk down the chain of map value types; the innermost type must equal
        # ``expected`` and must be reached within ``expected_depth`` levels.
        while True:
            assert curr_depth <= expected_depth
            assert lit is not None
            if lit.map_value_type is None:
                assert lit == expected
                return
            lit = lit.map_value_type
            curr_depth += 1

    # Type inference: untyped dicts and non-str keys fall back to STRUCT.
    for untyped in (dict, typing.Dict[int, int]):
        assert_struct(d.get_literal_type(untyped))

    # Single-level maps with str keys.
    flat_cases = [
        (typing.Dict[str, str], SimpleType.STRING),
        (typing.Dict[str, int], SimpleType.INTEGER),
        (typing.Dict[str, datetime.datetime], SimpleType.DATETIME),
        (typing.Dict[str, datetime.timedelta], SimpleType.DURATION),
        (typing.Dict[str, dict], SimpleType.STRUCT),
    ]
    for py_type, simple in flat_cases:
        recursive_assert(d.get_literal_type(py_type), LiteralType(simple=simple))

    # Nested maps, with the depth at which the innermost type is expected.
    nested_cases = [
        (typing.Dict[str, typing.Dict[str, str]], SimpleType.STRING, 2),
        (typing.Dict[str, typing.Dict[int, str]], SimpleType.STRUCT, 2),
        (typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]], SimpleType.STRING, 3),
        (typing.Dict[str, typing.Dict[str, typing.Dict[str, dict]]], SimpleType.STRUCT, 3),
        (typing.Dict[str, typing.Dict[str, typing.Dict[int, dict]]], SimpleType.STRUCT, 2),
    ]
    for py_type, simple, depth in nested_cases:
        recursive_assert(
            d.get_literal_type(py_type),
            LiteralType(simple=simple),
            expected_depth=depth,
        )

    ctx = FlyteContext.current_context()

    # An empty dict survives the round trip.
    lit = d.to_literal(ctx, {}, typing.Dict, LiteralType(SimpleType.STRUCT))
    pv = d.to_python_value(ctx, lit, typing.Dict)
    assert pv == {}

    # Literal to python: unsuitable literals are rejected with TypeError.
    with pytest.raises(TypeError):
        d.to_python_value(ctx, Literal(scalar=Scalar(primitive=Primitive(integer=10))), dict)
    with pytest.raises(TypeError):
        d.to_python_value(ctx, Literal(), dict)
    with pytest.raises(TypeError):
        d.to_python_value(ctx, Literal(map=LiteralMap(literals={"x": None})), dict)
    with pytest.raises(TypeError):
        d.to_python_value(ctx, Literal(map=LiteralMap(literals={"x": None})), typing.Dict[int, str])

    # A well-formed map literal converts without raising.
    one = Literal(scalar=Scalar(primitive=Primitive(integer=1)))
    d.to_python_value(
        ctx,
        Literal(map=LiteralMap(literals={"x": one})),
        typing.Dict[str, int],
    )
from flytekit.core.type_engine import _are_types_castable
from flytekit.models.annotation import TypeAnnotation
from flytekit.models.core.types import EnumType
from flytekit.models.types import (
    LiteralType,
    SimpleType,
    StructuredDatasetType,
    TypeStructure,
    UnionType,
)

# Simple literal types shared by the castability tests in this module.
str_type = LiteralType(simple=SimpleType.STRING)
int_type = LiteralType(simple=SimpleType.INTEGER)
none_type = LiteralType(simple=SimpleType.NONE)
bool_type = LiteralType(simple=SimpleType.BOOLEAN)

# Union literal types built from the simple ones above.
str_or_int = LiteralType(union_type=UnionType([str_type, int_type]))
int_or_str = LiteralType(union_type=UnionType([int_type, str_type]))
str_or_int_or_bool = LiteralType(union_type=UnionType([str_type, int_type, bool_type]))
optional_str = LiteralType(union_type=UnionType([str_type, none_type]))


def test_simple():
    """A simple type is castable to itself and to no other simple type."""
    assert _are_types_castable(str_type, str_type)
    for upstream, downstream in ((str_type, int_type), (int_type, str_type)):
        assert not _are_types_castable(upstream, downstream)


def test_metadata():
    """Differing metadata does not prevent casting, and the check leaves it intact."""
    a = LiteralType(simple=SimpleType.STRING, metadata={"test": 456})
    other = LiteralType(simple=SimpleType.STRING, metadata={"test": 123})
    assert _are_types_castable(a, other)
    # must not clobber metadata
    assert a.metadata == {"test": 456}
def test_map():
    """Maps are castable only when their value types are."""

    def map_of(value_type):
        return LiteralType(map_value_type=value_type)

    assert _are_types_castable(map_of(str_type), map_of(str_type))
    assert not _are_types_castable(map_of(str_type), map_of(int_type))
def get_literal_type(self, t: Type[MyDataset]) -> LiteralType:
    """
    Tell the Flytekit type system which literal type ``MyDataset`` corresponds to.

    In this example we say it is of format binary (do not try to introspect it)
    and that there is more than one file in it.

    :param t: The python type being described; not consulted here.
    :return: A LiteralType whose blob is ``self._TYPE_INFO``.
    """
    blob_info = self._TYPE_INFO
    return LiteralType(blob=blob_info)
def test_enum():
    """An enum can be cast to str, but a str cannot be cast to an enum."""
    enum_lt = LiteralType(enum_type=EnumType(["a", "b"]))

    # enum is a str
    enum_to_str = _are_types_castable(enum_lt, str_type)
    assert enum_to_str

    # a str is not necessarily an enum
    str_to_enum = _are_types_castable(str_type, enum_lt)
    assert not str_to_enum
def get_literal_type(self, t: Type[FlyteSchema]) -> LiteralType:
    """
    Return the schema literal type for a FlyteSchema type.

    :param t: The FlyteSchema type, passed on to ``self._get_schema_type``.
    :return: A LiteralType wrapping the schema type computed for ``t``.
    """
    schema_type = self._get_schema_type(t)
    return LiteralType(schema=schema_type)
def test_collection():
    """Collections are castable only when their element types are."""

    def list_of(element_type):
        return LiteralType(collection_type=element_type)

    assert _are_types_castable(list_of(str_type), list_of(str_type))
    assert not _are_types_castable(list_of(str_type), list_of(int_type))
def get_literal_type(self, t: Type[pyspark.sql.DataFrame]) -> LiteralType:
    """
    Return the schema literal type used for Spark DataFrames.

    :param t: The dataframe type; not consulted here.
    :return: A LiteralType wrapping ``self._get_schema_type()``.
    """
    schema_type = self._get_schema_type()
    return LiteralType(schema=schema_type)
def get_literal_type(self, t: Type[ScikitLearn2ONNX]) -> LiteralType:
    """
    Return the blob literal type used for ScikitLearn2ONNX values.

    :param t: The python type being described; not consulted here.
    :return: A LiteralType for a SINGLE-dimensionality blob whose format is
        ``self.ONNX_FORMAT``.
    """
    onnx_blob = BlobType(
        format=self.ONNX_FORMAT,
        dimensionality=BlobType.BlobDimensionality.SINGLE,
    )
    return LiteralType(blob=onnx_blob)