예제 #1
0
def test_fill_in_literal_type():
    """Check that the engine fills in the literal type an encoder leaves blank.

    A stub encoder for ``MyDF`` returns a bare ``literals.StructuredDataset``
    with no metadata. The engine is expected to populate the format itself.
    The test then checks that looking up an encoder for an unknown format
    ("rando") falls back to the encoder registered under the empty format.
    """

    class TempEncoder(StructuredDatasetEncoder):
        """Stub ``MyDF`` encoder on ``tmpfs://`` that returns an empty literal."""

        def __init__(self, fmt: str):
            super().__init__(MyDF, "tmpfs://", supported_format=fmt)

        def encode(
            self,
            ctx: FlyteContext,
            structured_dataset: StructuredDataset,
            structured_dataset_type: StructuredDatasetType,
        ) -> literals.StructuredDataset:
            # Deliberately return no metadata/type so the engine must supply it.
            return literals.StructuredDataset(uri="")

    avro_encoder = TempEncoder("myavro")
    StructuredDatasetTransformerEngine.register(avro_encoder, default_for_type=True)
    literal_type = TypeEngine.to_literal_type(MyDF)
    assert literal_type.structured_dataset_type.format == "myavro"

    ctx = FlyteContextManager.current_context()
    engine = StructuredDatasetTransformerEngine()
    dataset = StructuredDataset(dataframe=42)
    literal = engine.to_literal(ctx, dataset, MyDF, literal_type)
    # The encoder above never sets the format; the engine has to fill it in.
    assert literal.scalar.structured_dataset.metadata.structured_dataset_type.format == "myavro"

    # An encoder registered under the empty format acts as the fallback
    # for formats that have no dedicated encoder.
    catch_all_encoder = TempEncoder("")
    StructuredDatasetTransformerEngine.register(catch_all_encoder, default_for_type=False)

    looked_up = StructuredDatasetTransformerEngine.get_encoder(MyDF, "tmpfs", "rando")
    assert looked_up is catch_all_encoder
예제 #2
0
def get_subset_df(
    df: Annotated[StructuredDataset, subset_cols]
) -> Annotated[StructuredDataset, subset_cols]:
    """Append one extra ``Age`` row to the incoming dataset and re-wrap it.

    The input dataset is opened as a pandas DataFrame, and a single-row frame
    with ``Age == 30`` is concatenated onto it. The result is returned as a new
    ``StructuredDataset``.
    """
    frame = df.open(pd.DataFrame).all()
    extra_row = pd.DataFrame([[30]], columns=["Age"])
    combined = pd.concat([frame, extra_row])
    # No uri is given here. If a BigQuery uri were passed to StructuredDataset,
    # Flytekit would write the pandas DataFrame to a BigQuery table instead.
    return StructuredDataset(dataframe=combined)
예제 #3
0
    def consume(df: subset_schema) -> subset_schema:
        """Check that ``col2`` reads back as a, b, c via polars; return it re-wrapped."""
        frame = df.open(pl.DataFrame).all()

        for row, expected in enumerate(["a", "b", "c"]):
            assert frame["col2"][row] == expected

        return StructuredDataset(dataframe=frame)
def wf():
    """Invoke each structured-dataset conversion task once.

    Tasks receive either a freshly generated in-memory input (a pandas
    DataFrame, a numpy array, or an Arrow table) or a ``StructuredDataset``
    that only references a stored uri. The call order is kept exactly as-is.
    Flyte presumably derives node ids from invocation order, so do not reorder.
    """
    pandas_df = generate_pandas()
    numpy_array = generate_numpy()
    arrow_table = generate_arrow()
    t1(dataframe=pandas_df)
    t1a(dataframe=pandas_df)
    t2(dataframe=pandas_df)
    t3(dataset=StructuredDataset(uri=PANDAS_PATH))
    t3a(dataset=StructuredDataset(uri=PANDAS_PATH))
    t4(dataset=StructuredDataset(uri=PANDAS_PATH))
    t5(dataframe=pandas_df)
    t6(dataset=StructuredDataset(uri=BQ_PATH))
    t7(df1=pandas_df, df2=pandas_df)
    t8(dataframe=arrow_table)
    t8a(dataframe=arrow_table)
    t9(dataframe=numpy_array)
    t10(dataset=StructuredDataset(uri=NUMPY_PATH))
예제 #5
0
    def consume(df: full_schema) -> full_schema:
        """Verify both columns via polars, then return the rows sorted by ``col1``."""
        frame = df.open(pl.DataFrame).all()

        for row, expected in enumerate([1, 3, 2]):
            assert frame["col1"][row] == expected
        for row, expected in enumerate(["a", "b", "c"]):
            assert frame["col2"][row] == expected

        ordered = frame.sort("col1")
        return StructuredDataset(dataframe=ordered)
def test_pandas():
    """Round-trip a small pandas DataFrame through the parquet encoder/decoder pair."""
    original = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    parquet_encoder = basic_dfs.PandasToParquetEncodingHandler("/")
    parquet_decoder = basic_dfs.ParquetToPandasDecodingHandler("/")

    ctx = context_manager.FlyteContextManager.current_context()
    wrapped = StructuredDataset(dataframe=original)
    parquet_type = StructuredDatasetType(format="parquet")
    encoded = parquet_encoder.encode(ctx, wrapped, parquet_type)

    decoded = parquet_decoder.decode(ctx, encoded, StructuredDatasetMetadata(parquet_type))
    assert original.equals(decoded)
예제 #7
0
def test_to_literal():
    """Exercise ``StructuredDatasetTransformerEngine.to_literal`` on its main inputs.

    Covers:
      * a raw pandas DataFrame, which is encoded with the PARQUET format;
      * a StructuredDataset carrying both a dataframe and a literal, which is rejected;
      * an empty StructuredDataset, which is rejected;
      * a uri-only StructuredDataset, which keeps its uri and takes its format
        from the expected literal type.
    """
    ctx = FlyteContextManager.current_context()
    lt = TypeEngine.to_literal_type(pd.DataFrame)
    df = generate_pandas()

    fdt = StructuredDatasetTransformerEngine()

    lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
    assert lit.scalar.structured_dataset.metadata.structured_dataset_type.format == PARQUET

    # Supplying both a dataframe and an already-built literal is ambiguous.
    sd_with_literal_and_df = StructuredDataset(df)
    sd_with_literal_and_df._literal_sd = lit

    with pytest.raises(ValueError, match="Shouldn't have specified both literal"):
        fdt.to_literal(ctx, sd_with_literal_and_df, python_type=StructuredDataset, expected=lt)

    # Neither a dataframe nor a uri: nothing to convert.
    sd_with_nothing = StructuredDataset()
    with pytest.raises(ValueError, match="If dataframe is not specified"):
        fdt.to_literal(ctx, sd_with_nothing, python_type=StructuredDataset, expected=lt)

    # A uri-only dataset is passed through; the format comes from the annotation.
    sd_with_uri = StructuredDataset(uri="s3://some/extant/df.parquet")

    lt = TypeEngine.to_literal_type(Annotated[StructuredDataset, {}, "new-df-format"])
    lit = fdt.to_literal(ctx, sd_with_uri, python_type=StructuredDataset, expected=lt)
    assert lit.scalar.structured_dataset.uri == "s3://some/extant/df.parquet"
    assert lit.scalar.structured_dataset.metadata.structured_dataset_type.format == "new-df-format"
예제 #8
0
 def generate() -> full_schema:
     """Build a 3-row polars frame (col1: 1, 3, 2; col2: a, b, c) as a StructuredDataset."""
     frame = pl.DataFrame({"col1": [1, 3, 2], "col2": ["a", "b", "c"]})
     return StructuredDataset(dataframe=frame)
def t9(dataframe: np.ndarray) -> Annotated[StructuredDataset, my_cols]:
    """Wrap a numpy array for storage at ``NUMPY_PATH``.

    Pipeline: numpy -> Arrow table -> s3 (parquet).
    """
    wrapped = StructuredDataset(dataframe=dataframe, uri=NUMPY_PATH)
    return wrapped
def t7(
    df1: pd.DataFrame, df2: pd.DataFrame
) -> (Annotated[StructuredDataset, my_cols], Annotated[StructuredDataset, my_cols]):
    """Send two pandas frames to two different destinations.

    df1: pandas -> bq (written to ``BQ_PATH``)
    df2: pandas -> s3 (parquet)
    """
    bigquery_output = StructuredDataset(dataframe=df1, uri=BQ_PATH)
    parquet_output = StructuredDataset(dataframe=df2)
    return bigquery_output, parquet_output
def t5(dataframe: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]:
    """Write the incoming frame to ``BQ_PATH``: s3 (parquet) -> pandas -> bq."""
    bigquery_dataset = StructuredDataset(dataframe=dataframe, uri=BQ_PATH)
    return bigquery_dataset
def t1a(dataframe: pd.DataFrame) -> Annotated[StructuredDataset, my_cols, PARQUET]:
    """Write the incoming frame to ``PANDAS_PATH``: S3 (parquet) -> Pandas -> S3 (parquet)."""
    parquet_dataset = StructuredDataset(dataframe=dataframe, uri=PANDAS_PATH)
    return parquet_dataset
예제 #13
0
def to_numpy(
    ds: Annotated[StructuredDataset, subset_cols]
) -> Annotated[StructuredDataset, subset_cols, PARQUET]:
    """Open ``ds`` as a numpy array and re-wrap it for PARQUET output."""
    return StructuredDataset(dataframe=ds.open(np.ndarray).all())
예제 #14
0
 def t1() -> Annotated[StructuredDataset, "avro"]:
     """Wrap ``df`` (a name from an outer scope) as a dataset annotated "avro"."""
     avro_dataset = StructuredDataset(dataframe=df)
     return avro_dataset
예제 #15
0
def test_sd():
    """Exercise StructuredDataset's ``all()``/``iter()`` paths and decoder dispatch.

    * With no dataframe type set, both ``all()`` and ``iter()`` raise ValueError.
    * With a "tmpfs" pandas decoder that yields (a generator), ``iter()``
      returns a generator and ``all()`` raises ValueError.
    * With a "tmpfs" pandas decoder that returns a DataFrame directly,
      ``iter()`` raises ValueError.
    """
    sd = StructuredDataset(dataframe="hi")
    sd.uri = "my uri"
    assert sd.file_format == PARQUET

    with pytest.raises(ValueError, match="No dataframe type set"):
        sd.all()

    with pytest.raises(ValueError, match="No dataframe type set."):
        sd.iter()

    # Distinct class names: previously both mocks were named
    # MockPandasDecodingHandlers, so the second silently shadowed the first.
    class MockPandasGeneratorDecodingHandler(StructuredDatasetDecoder):
        """Decoder whose ``decode`` is a generator yielding one DataFrame."""

        def decode(
            self,
            ctx: FlyteContext,
            flyte_value: literals.StructuredDataset,
            current_task_metadata: StructuredDatasetMetadata,
        ) -> typing.Generator[pd.DataFrame, None, None]:
            yield pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

    StructuredDatasetTransformerEngine.register(
        MockPandasGeneratorDecodingHandler(pd.DataFrame, "tmpfs"),
        default_for_type=False,
    )
    sd = StructuredDataset()
    # Attach a literal directly (private attribute) pointing at a tmpfs uri.
    sd._literal_sd = literals.StructuredDataset(
        uri="tmpfs://somewhere",
        metadata=StructuredDatasetMetadata(StructuredDatasetType(format="")),
    )
    assert isinstance(sd.open(pd.DataFrame).iter(), typing.Generator)

    with pytest.raises(ValueError):
        sd.open(pd.DataFrame).all()

    class MockPandasFrameDecodingHandler(StructuredDatasetDecoder):
        """Decoder whose ``decode`` returns one DataFrame directly."""

        def decode(
            self,
            ctx: FlyteContext,
            flyte_value: literals.StructuredDataset,
            current_task_metadata: StructuredDatasetMetadata,
        ) -> pd.DataFrame:
            return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

    # override=True replaces the generator decoder registered above.
    StructuredDatasetTransformerEngine.register(
        MockPandasFrameDecodingHandler(pd.DataFrame, "tmpfs"),
        default_for_type=False,
        override=True,
    )
    sd = StructuredDataset()
    sd._literal_sd = literals.StructuredDataset(
        uri="tmpfs://somewhere",
        metadata=StructuredDatasetMetadata(StructuredDatasetType(format="")),
    )

    with pytest.raises(ValueError):
        sd.open(pd.DataFrame).iter()
예제 #16
0
def get_subset_df(
    df: Annotated[pd.DataFrame, superset_cols]
) -> Annotated[StructuredDataset, subset_cols]:
    """Append a single-row frame with ``age == 30`` to ``df``; return it as a dataset."""
    extra_row = pd.DataFrame([[30]], columns=["age"])
    return StructuredDataset(dataframe=pd.concat([df, extra_row]))
예제 #17
0
def show_sd(in_sd: StructuredDataset):
    """Open ``in_sd`` as a pandas DataFrame and print it.

    The wide-frame wrapping option is disabled only for this print, via a
    scoped ``pd.option_context``. The original ``pd.set_option`` call
    permanently changed pandas' process-wide display settings as a side effect.
    """
    df = in_sd.open(pd.DataFrame).all()
    with pd.option_context("expand_frame_repr", False):
        print(df)
def t8(dataframe: pa.Table) -> Annotated[StructuredDataset, my_cols]:
    """Print the Arrow table's ``columns``, then wrap it: Arrow table -> s3 (parquet)."""
    print(dataframe.columns)
    arrow_dataset = StructuredDataset(dataframe=dataframe)
    return arrow_dataset