def test_fill_in_literal_type():
    """Verify the engine fills in the literal's format metadata and that
    encoder lookup for an unknown format falls back to the "" encoder."""

    class TempEncoder(StructuredDatasetEncoder):
        def __init__(self, fmt: str):
            super().__init__(MyDF, "tmpfs://", supported_format=fmt)

        def encode(
            self,
            ctx: FlyteContext,
            structured_dataset: StructuredDataset,
            structured_dataset_type: StructuredDatasetType,
        ) -> literals.StructuredDataset:
            # Deliberately return a bare literal so the engine must fill in the type.
            return literals.StructuredDataset(uri="")

    StructuredDatasetTransformerEngine.register(TempEncoder("myavro"), default_for_type=True)
    literal_type = TypeEngine.to_literal_type(MyDF)
    assert literal_type.structured_dataset_type.format == "myavro"

    ctx = FlyteContextManager.current_context()
    engine = StructuredDatasetTransformerEngine()
    sd = StructuredDataset(dataframe=42)
    lit = engine.to_literal(ctx, sd, MyDF, literal_type)
    # encode() above left the type empty; the engine should have filled it in.
    assert lit.scalar.structured_dataset.metadata.structured_dataset_type.format == "myavro"

    # Looking up an encoder for an unregistered format ("rando") should fall
    # back to the encoder registered under the empty-string format.
    empty_format_temp_encoder = TempEncoder("")
    StructuredDatasetTransformerEngine.register(empty_format_temp_encoder, default_for_type=False)
    res = StructuredDatasetTransformerEngine.get_encoder(MyDF, "tmpfs", "rando")
    assert res is empty_format_temp_encoder
def get_subset_df(
    df: Annotated[StructuredDataset, subset_cols]
) -> Annotated[StructuredDataset, subset_cols]:
    """Materialize the dataset as pandas, append one extra Age row, and return it."""
    full_frame = df.open(pd.DataFrame).all()
    extended = pd.concat([full_frame, pd.DataFrame([[30]], columns=["Age"])])
    # On specifying a BigQuery uri for StructuredDataset, Flytekit will write
    # the pandas dataframe to a BigQuery table.
    return StructuredDataset(dataframe=extended)
def consume(df: subset_schema) -> subset_schema:
    """Check col2 holds 'a', 'b', 'c' in order, then pass the frame back out."""
    frame = df.open(pl.DataFrame).all()
    for row, expected in enumerate(("a", "b", "c")):
        assert frame["col2"][row] == expected
    return StructuredDataset(dataframe=frame)
def wf():
    """Drive every task variant with pandas, numpy and arrow inputs.

    Call order is preserved from the original workflow definition.
    """
    pandas_df = generate_pandas()
    numpy_arr = generate_numpy()
    arrow_tbl = generate_arrow()
    t1(dataframe=pandas_df)
    t1a(dataframe=pandas_df)
    t2(dataframe=pandas_df)
    t3(dataset=StructuredDataset(uri=PANDAS_PATH))
    t3a(dataset=StructuredDataset(uri=PANDAS_PATH))
    t4(dataset=StructuredDataset(uri=PANDAS_PATH))
    t5(dataframe=pandas_df)
    t6(dataset=StructuredDataset(uri=BQ_PATH))
    t7(df1=pandas_df, df2=pandas_df)
    t8(dataframe=arrow_tbl)
    t8a(dataframe=arrow_tbl)
    t9(dataframe=numpy_arr)
    t10(dataset=StructuredDataset(uri=NUMPY_PATH))
def consume(df: full_schema) -> full_schema:
    """Verify both columns row by row, then return the frame sorted by col1."""
    frame = df.open(pl.DataFrame).all()
    expected = {"col1": [1, 3, 2], "col2": ["a", "b", "c"]}
    for column, values in expected.items():
        for row, value in enumerate(values):
            assert frame[column][row] == value
    return StructuredDataset(dataframe=frame.sort("col1"))
def test_pandas():
    """Round-trip a pandas frame through the parquet encode/decode handlers."""
    original = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    encoder = basic_dfs.PandasToParquetEncodingHandler("/")
    decoder = basic_dfs.ParquetToPandasDecodingHandler("/")
    ctx = context_manager.FlyteContextManager.current_context()
    sd_type = StructuredDatasetType(format="parquet")
    encoded_literal = encoder.encode(ctx, StructuredDataset(dataframe=original), sd_type)
    round_tripped = decoder.decode(ctx, encoded_literal, StructuredDatasetMetadata(sd_type))
    # The decoded frame must be value-identical to what went in.
    assert original.equals(round_tripped)
def test_to_literal():
    """Cover StructuredDatasetTransformerEngine.to_literal paths.

    Cases: a raw dataframe (default parquet format), a StructuredDataset with
    both a literal and a dataframe (error), an empty StructuredDataset (error),
    and a uri-only StructuredDataset (format taken from the expected type).
    """
    ctx = FlyteContextManager.current_context()
    lt = TypeEngine.to_literal_type(pd.DataFrame)
    df = generate_pandas()

    fdt = StructuredDatasetTransformerEngine()

    lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
    # A raw dataframe is encoded with the default parquet format.
    # (The original test asserted this identical line twice; the duplicate was removed.)
    assert lit.scalar.structured_dataset.metadata.structured_dataset_type.format == PARQUET

    # Supplying both an already-built literal and a dataframe is ambiguous.
    sd_with_literal_and_df = StructuredDataset(df)
    sd_with_literal_and_df._literal_sd = lit
    with pytest.raises(ValueError, match="Shouldn't have specified both literal"):
        fdt.to_literal(ctx, sd_with_literal_and_df, python_type=StructuredDataset, expected=lt)

    # Neither a dataframe nor a uri: nothing to convert.
    sd_with_nothing = StructuredDataset()
    with pytest.raises(ValueError, match="If dataframe is not specified"):
        fdt.to_literal(ctx, sd_with_nothing, python_type=StructuredDataset, expected=lt)

    # A uri-only StructuredDataset passes the uri through and takes its format
    # from the annotated literal type.
    sd_with_uri = StructuredDataset(uri="s3://some/extant/df.parquet")
    lt = TypeEngine.to_literal_type(Annotated[StructuredDataset, {}, "new-df-format"])
    lit = fdt.to_literal(ctx, sd_with_uri, python_type=StructuredDataset, expected=lt)
    assert lit.scalar.structured_dataset.uri == "s3://some/extant/df.parquet"
    assert lit.scalar.structured_dataset.metadata.structured_dataset_type.format == "new-df-format"
def generate() -> full_schema:
    """Build a three-row polars frame with col1 deliberately out of order."""
    frame = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")})
    return StructuredDataset(dataframe=frame)
def t9(dataframe: np.ndarray) -> Annotated[StructuredDataset, my_cols]:
    """numpy -> Arrow table -> s3 (parquet)."""
    return StructuredDataset(dataframe=dataframe, uri=NUMPY_PATH)
def t7(
    df1: pd.DataFrame, df2: pd.DataFrame
) -> (Annotated[StructuredDataset, my_cols], Annotated[StructuredDataset, my_cols]):
    """Fan out two frames to different sinks.

    df1: pandas -> bq
    df2: pandas -> s3 (parquet)
    """
    bq_out = StructuredDataset(dataframe=df1, uri=BQ_PATH)
    s3_out = StructuredDataset(dataframe=df2)
    return bq_out, s3_out
def t5(dataframe: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]:
    """s3 (parquet) -> pandas -> bq."""
    return StructuredDataset(dataframe=dataframe, uri=BQ_PATH)
def t1a(dataframe: pd.DataFrame) -> Annotated[StructuredDataset, my_cols, PARQUET]:
    """S3 (parquet) -> Pandas -> S3 (parquet)."""
    return StructuredDataset(dataframe=dataframe, uri=PANDAS_PATH)
def to_numpy(
    ds: Annotated[StructuredDataset, subset_cols]
) -> Annotated[StructuredDataset, subset_cols, PARQUET]:
    """Materialize the dataset as a numpy array and wrap it in a new StructuredDataset."""
    arr = ds.open(np.ndarray).all()
    return StructuredDataset(dataframe=arr)
def t1() -> Annotated[StructuredDataset, "avro"]:
    """Wrap the dataframe in a StructuredDataset annotated with the avro format.

    NOTE(review): reads `df` from enclosing scope — presumably a module-level
    frame defined elsewhere in this file; confirm against the full module.
    """
    return StructuredDataset(dataframe=df)
def test_sd():
    """Exercise StructuredDataset accessors and decoder registration/override.

    A generator-style decoder supports iter() but not all(); after overriding
    the registration with a frame-returning decoder, iter() fails instead.
    """
    sd = StructuredDataset(dataframe="hi")
    sd.uri = "my uri"
    assert sd.file_format == PARQUET

    # No dataframe type has been set via open(), so both accessors must raise.
    with pytest.raises(ValueError, match="No dataframe type set"):
        sd.all()
    with pytest.raises(ValueError, match="No dataframe type set."):
        sd.iter()

    class MockPandasDecodingHandlers(StructuredDatasetDecoder):
        def decode(
            self,
            ctx: FlyteContext,
            flyte_value: literals.StructuredDataset,
            current_task_metadata: StructuredDatasetMetadata,
        ) -> typing.Union[typing.Generator[pd.DataFrame, None, None]]:
            yield pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

    StructuredDatasetTransformerEngine.register(
        MockPandasDecodingHandlers(pd.DataFrame, "tmpfs"), default_for_type=False
    )
    sd = StructuredDataset()
    sd._literal_sd = literals.StructuredDataset(
        uri="tmpfs://somewhere",
        metadata=StructuredDatasetMetadata(StructuredDatasetType(format="")),
    )
    # Generator decoder: iter() yields, all() is unsupported.
    assert isinstance(sd.open(pd.DataFrame).iter(), typing.Generator)
    with pytest.raises(ValueError):
        sd.open(pd.DataFrame).all()

    # Re-register (override=True) with a decoder that returns a concrete frame.
    class MockPandasDecodingHandlers(StructuredDatasetDecoder):  # noqa: F811
        def decode(
            self,
            ctx: FlyteContext,
            flyte_value: literals.StructuredDataset,
            current_task_metadata: StructuredDatasetMetadata,
        ) -> pd.DataFrame:
            return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

    StructuredDatasetTransformerEngine.register(
        MockPandasDecodingHandlers(pd.DataFrame, "tmpfs"), default_for_type=False, override=True
    )
    sd = StructuredDataset()
    sd._literal_sd = literals.StructuredDataset(
        uri="tmpfs://somewhere",
        metadata=StructuredDatasetMetadata(StructuredDatasetType(format="")),
    )
    # Frame decoder: iter() is now the unsupported path.
    with pytest.raises(ValueError):
        sd.open(pd.DataFrame).iter()
def get_subset_df(
    df: Annotated[pd.DataFrame, superset_cols]
) -> Annotated[StructuredDataset, subset_cols]:
    """Append one extra age row and return the frame typed with the subset columns."""
    extended = pd.concat([df, pd.DataFrame([[30]], columns=["age"])])
    return StructuredDataset(dataframe=extended)
def show_sd(in_sd: StructuredDataset):
    """Open the dataset as pandas and print it without wrapping wide frames."""
    pd.set_option("expand_frame_repr", False)
    frame = in_sd.open(pd.DataFrame).all()
    print(frame)
def t8(dataframe: pa.Table) -> Annotated[StructuredDataset, my_cols]:
    """Arrow table -> s3 (parquet)."""
    print(dataframe.columns)
    return StructuredDataset(dataframe=dataframe)