Example #1
def test_wf_container_task_multiple():
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
    )

    sum = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum(x=square(val=val1), y=square(val=val2))

    with task_mock(square) as square_mock, task_mock(sum) as sum_mock:
        square_mock.side_effect = lambda val: val * val
        assert square(val=10) == 100

        sum_mock.side_effect = lambda x, y: x + y
        assert sum(x=10, y=10) == 20

        assert raw_container_wf(val1=10, val2=10) == 200
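
A minimal sketch of the imports this mocking pattern assumes (module paths taken from flytekit's public API; task_mock is the testing helper used throughout these examples):

from flytekit import ContainerTask, kwtypes, workflow
from flytekit.core.testing import task_mock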
Example #2
def test_wf1_with_sql():
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    with task_mock(sql) as mock:
        mock.return_value = pandas.DataFrame(data={
            "x": [1, 2],
            "y": ["3", "4"]
        })
        assert (my_wf().open().all() == pandas.DataFrame(data={
            "x": [1, 2],
            "y": ["3", "4"]
        })).all().all()
    assert context_manager.FlyteContextManager.size() == 1
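
The trailing size() assertion checks that the context stack unwound cleanly: once the workflow and the task_mock block finish, only the base FlyteContext should remain.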
Example #3
def test_to_python_value_with_incoming_columns():
    # make a literal with a type that has two columns
    original_type = Annotated[pd.DataFrame, kwtypes(name=str, age=int)]
    ctx = FlyteContextManager.current_context()
    lt = TypeEngine.to_literal_type(original_type)
    df = generate_pandas()
    fdt = StructuredDatasetTransformerEngine()
    lit = fdt.to_literal(ctx, df, python_type=original_type, expected=lt)
    assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 2

    # declare a new type that only has one column
    # get the dataframe, make sure it has the column that was asked for.
    subset_sd_type = Annotated[StructuredDataset, kwtypes(age=int)]
    sd = fdt.to_python_value(ctx, lit, subset_sd_type)
    assert sd.metadata.structured_dataset_type.columns[0].name == "age"
    sub_df = sd.open(pd.DataFrame).all()
    assert sub_df.shape[1] == 1

    # check when columns are not specified, should pull both and add column information.
    sd = fdt.to_python_value(ctx, lit, StructuredDataset)
    assert len(sd.metadata.structured_dataset_type.columns) == 2

    # should also work if subset type is just an annotated pd.DataFrame
    subset_pd_type = Annotated[pd.DataFrame, kwtypes(age=int)]
    sub_df = fdt.to_python_value(ctx, lit, subset_pd_type)
    assert sub_df.shape[1] == 1
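
Examples #3 and #4 lean on a generate_pandas helper that the source never shows; below is a plausible stand-in, plus the imports these tests appear to assume:

import pandas as pd
from typing_extensions import Annotated
from flytekit import kwtypes
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.types.structured.structured_dataset import (
    StructuredDataset,
    StructuredDatasetTransformerEngine,
)

# Hypothetical helper: any frame with "name" and "age" columns satisfies the asserts.
def generate_pandas() -> pd.DataFrame:
    return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})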
Example #4
def test_to_python_value_without_incoming_columns():
    # make a literal with a type with no columns
    ctx = FlyteContextManager.current_context()
    lt = TypeEngine.to_literal_type(pd.DataFrame)
    df = generate_pandas()
    fdt = StructuredDatasetTransformerEngine()
    lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
    assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 0

    # declare a new type that only has one column
    # get the dataframe, make sure it has the column that was asked for.
    subset_sd_type = Annotated[StructuredDataset, kwtypes(age=int)]
    sd = fdt.to_python_value(ctx, lit, subset_sd_type)
    assert sd.metadata.structured_dataset_type.columns[0].name == "age"
    sub_df = sd.open(pd.DataFrame).all()
    assert sub_df.shape[1] == 1

    # check when columns are not specified, should pull both and add column information.
    # todo: see the todos in the open_as, and iter_as functions in StructuredDatasetTransformerEngine
    #  we have to recreate the literal because the test case above filled in the metadata
    lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
    sd = fdt.to_python_value(ctx, lit, StructuredDataset)
    assert sd.metadata.structured_dataset_type.columns == []
    sub_df = sd.open(pd.DataFrame).all()
    assert sub_df.shape[1] == 2

    # should also work if subset type is just an annotated pd.DataFrame
    lit = fdt.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
    subset_pd_type = Annotated[pd.DataFrame, kwtypes(age=int)]
    sub_df = fdt.to_python_value(ctx, lit, subset_pd_type)
    assert sub_df.shape[1] == 1
Example #5
def test_wf1_with_sql_with_patch():
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    @patch(sql)
    def test_user_demo_test(mock_sql):
        mock_sql.return_value = pandas.DataFrame(data={
            "x": [1, 2],
            "y": ["3", "4"]
        })
        assert (my_wf().open().all() == pandas.DataFrame(data={
            "x": [1, 2],
            "y": ["3", "4"]
        })).all().all()

    # Have to call because tests inside tests don't run
    test_user_demo_test()
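
Note that patch here is flytekit's own testing decorator, not unittest.mock.patch: it passes the wrapped function a mock that stands in for the decorated task. Assuming the usual location:

from flytekit.core.testing import patch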
Example #6
def test_task_serialization():
    sql_task = SQLite3Task(
        "test",
        query_template="select TrackId, Name from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLite3Config(
            uri=EXAMPLE_DB,
            compressed=True,
        ),
    )

    tt = sql_task.serialize_to_model(sql_task.SERIALIZE_SETTINGS)

    assert tt.container.args == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_customized_container_task.default_task_template_resolver",
        "--",
        "{{.taskTemplatePath}}",
        "flytekit.extras.sqlite3.task.SQLite3TaskExecutor",
    ]

    assert tt.custom["query_template"] == "select TrackId, Name from tracks limit {{.inputs.limit}}"
    assert tt.container.image != ""
Example #7
def test_ge_flytefile_multiple_args():
    task_object_one = GreatExpectationsTask(
        name="test13",
        datasource_name="data",
        inputs=kwtypes(dataset=FlyteFile),
        expectation_suite_name="test.demo",
        data_connector_name="data_flytetype_data_connector",
        local_file_path="/tmp",
    )
    task_object_two = GreatExpectationsTask(
        name="test14",
        datasource_name="data",
        inputs=kwtypes(dataset=FlyteFile),
        expectation_suite_name="test1.demo",
        data_connector_name="data_flytetype_data_connector",
        local_file_path="/tmp",
    )

    @task
    def get_file_name(dataset_one: FlyteFile, dataset_two: FlyteFile) -> typing.Tuple[int, int]:
        df_one = pd.read_csv(os.path.join("data", dataset_one))
        df_two = pd.read_csv(os.path.join("data", dataset_two))
        return len(df_one), len(df_two)

    @workflow
    def wf(
        dataset_one: FlyteFile = "https://raw.githubusercontent.com/superconductive/ge_tutorials/main/data/yellow_tripdata_sample_2019-01.csv",
        dataset_two: FlyteFile = "https://raw.githubusercontent.com/superconductive/ge_tutorials/main/data/yellow_tripdata_sample_2019-02.csv",
    ) -> typing.Tuple[int, int]:
        task_object_one(dataset=dataset_one)
        task_object_two(dataset=dataset_two)
        return get_file_name(dataset_one=dataset_one, dataset_two=dataset_two)

    assert wf() == (10000, 10000)
Example #8
def test_resolver_load_task():
    # any task is fine, just copied one
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh", "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"
        ],
    )

    resolver = TaskTemplateResolver()
    ts = get_serializable(OrderedDict(), serialization_settings, square)
    file = tempfile.NamedTemporaryFile().name
    # load_task should instantiate the class at the given import path; it doesn't need to be a real executor
    write_proto_to_file(ts.template.to_flyte_idl(), file)
    shim_task = resolver.load_task(
        [file, f"{Placeholder.__module__}.Placeholder"])
    assert isinstance(shim_task.executor, Placeholder)
    assert shim_task.task_template.id.name == "square"
    assert shim_task.task_template.interface.inputs["val"] is not None
    assert shim_task.task_template.interface.outputs["out"] is not None
Example #9
def test_execute_sqlite3_task(flyteclient, flyte_workflows_register,
                              flyte_remote_env):
    remote = FlyteRemote(Config.auto(), PROJECT, "development")

    example_db = "https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip"
    interactive_sql_task = SQLite3Task(
        "basic_querying",
        query_template="select TrackId, Name from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLite3Config(
            uri=example_db,
            compressed=True,
        ),
    )
    registered_sql_task = remote.register(interactive_sql_task)
    execution = remote.execute(registered_sql_task,
                               inputs={"limit": 10},
                               wait=True)
    output = execution.outputs["results"]
    result = output.open().all()
    assert result.__class__.__name__ == "DataFrame"
    assert "TrackId" in result
    assert "Name" in result
Example #10
def test_task_serialization_deserialization_with_secret(sql_server):
    secret_group = "foo"
    secret_name = "bar"

    sec = SecretsManager()
    os.environ[sec.get_secrets_env_var(secret_group, secret_name)] = "IMMEDIATE"

    sql_task = SQLAlchemyTask(
        "test",
        query_template="select 1;",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLAlchemyConfig(
            uri=sql_server,
            # As sqlite3 doesn't really support passwords, we pass another connect_arg as a secret
            secret_connect_args={
                "isolation_level": Secret(group=secret_group, key=secret_name)
            },
        ),
    )

    tt = sql_task.serialize_to_model(sql_task.SERIALIZE_SETTINGS)

    assert tt.container is not None

    assert tt.container.args == [
        "pyflyte-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_customized_container_task.default_task_template_resolver",
        "--",
        "{{.taskTemplatePath}}",
        "flytekitplugins.sqlalchemy.task.SQLAlchemyTaskExecutor",
    ]

    assert tt.custom["query_template"] == "select 1;"
    assert tt.container.image != ""

    assert "secret_connect_args" in tt.custom
    assert "isolation_level" in tt.custom["secret_connect_args"]
    assert tt.custom["secret_connect_args"]["isolation_level"][
        "group"] == secret_group
    assert tt.custom["secret_connect_args"]["isolation_level"][
        "key"] == secret_name
    assert tt.custom["secret_connect_args"]["isolation_level"][
        "group_version"] is None
    assert tt.custom["secret_connect_args"]["isolation_level"][
        "mount_requirement"] == 0

    executor = SQLAlchemyTaskExecutor()
    r = executor.execute_from_model(tt)

    assert r.iat[0, 0] == 1
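
In local execution, SecretsManager resolves secrets from environment variables rather than mounted files: get_secrets_env_var(group, key) computes the variable name (with default settings, something like _FSEC_FOO_BAR), which is why setting it before execute_from_model lets the executor build the connection.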
Example #11
def query_wf() -> int:
    df = SQLite3Task(
        name="cookbook.sqlite3.sample_inline",
        query_template="select TrackId, Name from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLite3Config(uri=EXAMPLE_DB, compressed=True),
    )(limit=100)
    return print_and_count_columns(df=df)
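
print_and_count_columns is assumed from the surrounding cookbook; a plausible stand-in:

@task
def print_and_count_columns(df: pandas.DataFrame) -> int:
    # Hypothetical downstream task: report how many columns came back from the query.
    return len(df.columns.values)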
Example #12
def test_task_schema():
    sql_task = SQLite3Task(
        "test",
        query_template="select TrackId, Name from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLite3Config(uri=EXAMPLE_DB, compressed=True),
    )

    assert sql_task.output_columns is not None
    df = sql_task(limit=1)
    assert df is not None
Example #13
def test_task_schema(sql_server):
    sql_task = SQLAlchemyTask(
        "test",
        query_template="select TrackId, Name from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)],
        task_config=SQLAlchemyConfig(uri=sql_server),
    )

    assert sql_task.output_columns is not None
    df = sql_task(limit=1)
    assert df is not None
Example #14
def test_serialization():
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh", "-c",
            "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"
        ],
    )

    sum = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=[
            "sh", "-c",
            "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"
        ],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum(x=square(val=val1), y=square(val=val2))

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings,
                               raw_container_wf)
    assert wf_spec is not None
    assert wf_spec.template is not None
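    # square is invoked twice and sum once, hence three nodes.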
    assert len(wf_spec.template.nodes) == 3
    sqn_spec = get_serializable(OrderedDict(), serialization_settings, square)
    assert sqn_spec.template.container.image == "alpine"
    sumn_spec = get_serializable(OrderedDict(), serialization_settings, sum)
    assert sumn_spec.template.container.image == "alpine"
Example #15
def test_wf_typed_schema():
    schema1 = FlyteSchema[kwtypes(x=int, y=str)]

    @task
    def t1() -> schema1:
        s = schema1()
        s.open().write(pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]}))
        return s

    @task
    def t2(s: FlyteSchema[kwtypes(x=int, y=str)]) -> FlyteSchema[kwtypes(x=int)]:
        df = s.open().all()
        return df[s.column_names()[:-1]]

    @workflow
    def wf() -> FlyteSchema[kwtypes(x=int)]:
        return t2(s=t1())

    w = t1()
    assert w is not None
    df = w.open(override_mode=SchemaOpenMode.READ).all()
    result_df = df.reset_index(drop=True) == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]}).reset_index(
        drop=True
    )
    assert result_df.all().all()

    df = t2(s=w.as_readonly())
    assert df is not None
    result_df = df.reset_index(drop=True) == pandas.DataFrame(data={"x": [1, 2]}).reset_index(drop=True)
    assert result_df.all().all()

    x = wf()
    df = x.open().all()
    result_df = df.reset_index(drop=True) == pandas.DataFrame(data={"x": [1, 2]}).reset_index(drop=True)
    assert result_df.all().all()
Example #16
def test_wf_container_task():
    @task
    def t1(a: int) -> (int, str):
        return a + 2, str(a) + "-HELLO"

    t2 = ContainerTask(
        "raw",
        image="alpine",
        inputs=kwtypes(a=int, b=str),
        input_data_dir="/tmp",
        output_data_dir="/tmp",
        command=["cat"],
        arguments=["/tmp/a"],
    )

    @workflow
    def wf(a: int):
        x, y = t1(a=a)
        t2(a=x, b=y)

    with task_mock(t2) as mock:
        mock.side_effect = lambda a, b: None
        assert t2(a=10, b="hello") is None

        wf(a=10)
Example #17
def test_input_output_substitution_files():
    script = "cat {inputs.f} > {outputs.y}"
    if os.name == "nt":
        script = script.replace("cat", "type")
    t = ShellTask(
        name="test",
        debug=True,
        script=script,
        inputs=kwtypes(f=CSVFile),
        output_locs=[
            OutputLocation(var="y", var_type=FlyteFile, location="{inputs.f}.mod"),
        ],
    )

    assert t.script == script

    contents = "1,2,3,4\n"
    with tempfile.TemporaryDirectory() as tmp:
        csv = os.path.join(tmp, "abc.csv")
        print(csv)
        with open(csv, "w") as f:
            f.write(contents)
        y = t(f=csv)
        assert y.path[-4:] == ".mod"
        assert os.path.exists(y.path)
        with open(y.path) as f:
            s = f.read()
        assert s == contents
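
The OutputLocation template "{inputs.f}.mod" is interpolated from the resolved input path at run time, which is why the test checks only the ".mod" suffix rather than a full path.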
Example #18
def test_wf_schema_to_df():
    schema1 = FlyteSchema[kwtypes(x=int, y=str)]

    @task(cache=True, cache_version="v0")
    def t1() -> schema1:
        global n_cached_task_calls
        n_cached_task_calls += 1

        s = schema1()
        s.open().write(pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]}))
        return s

    @task(cache=True, cache_version="v1")
    def t2(df: pandas.DataFrame) -> int:
        global n_cached_task_calls
        n_cached_task_calls += 1

        return len(df.columns.values)

    @workflow
    def wf() -> int:
        return t2(df=t1())

    assert n_cached_task_calls == 0
    x = wf()
    assert x == 2
    assert n_cached_task_calls == 2
    # Second call does not bump the counter
    x = wf()
    assert x == 2
    assert n_cached_task_calls == 2
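
The cached-task tests here and in Example #30 assume a module-level counter that each test resets beforehand; a minimal sketch of that setup:

n_cached_task_calls = 0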
Example #19
def test_fs_sd_compatibility():
    my_schema = FlyteSchema[kwtypes(name=str, age=int)]

    @task
    def my_dataset() -> pd.DataFrame:
        return pd.DataFrame(data={"name": ["Alice"], "age": [5]})

    @task(task_config=Spark())
    def my_spark(df: pyspark.sql.DataFrame) -> my_schema:
        session = flytekit.current_context().spark_session
        new_df = session.createDataFrame([("Bob", 10)],
                                         my_schema.column_names())
        return df.union(new_df)

    @task(task_config=Spark())
    def read_spark_df(df: pyspark.sql.DataFrame) -> int:
        return df.count()

    @workflow
    def my_wf() -> int:
        df = my_dataset()
        fs = my_spark(df=df)
        return read_spark_df(df=fs)

    res = my_wf()
    assert res == 2
Example #20
def test_workflow(sql_server):
    @task
    def my_task(df: pandas.DataFrame) -> int:
        return len(df[df.columns[0]])

    insert_task = SQLAlchemyTask(
        "test",
        query_template="insert into tracks values (5, 'flyte')",
        output_schema_type=None,
        task_config=SQLAlchemyConfig(uri=sql_server),
    )

    sql_task = SQLAlchemyTask(
        "test",
        query_template="select * from tracks limit {{.inputs.limit}}",
        inputs=kwtypes(limit=int),
        task_config=SQLAlchemyConfig(uri=sql_server),
    )

    @workflow
    def wf(limit: int) -> int:
        insert_task()
        return my_task(df=sql_task(limit=limit))

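    # 6 presumably reflects the rows seeded by the sql_server fixture plus the one inserted above.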
    assert wf(limit=10) == 6
Example #21
def test_ge_simple_task():
    task_object = GreatExpectationsTask(
        name="test1",
        datasource_name="data",
        inputs=kwtypes(dataset=str),
        expectation_suite_name="test.demo",
        data_connector_name="data_example_data_connector",
    )

    # valid data
    result = task_object(dataset="yellow_tripdata_sample_2019-01.csv")

    assert result["success"] is True
    assert result["statistics"]["evaluated_expectations"] == result["statistics"]["successful_expectations"]

    # invalid data
    with pytest.raises(ValidationError):
        invalid_result = task_object(dataset="yellow_tripdata_sample_2019-02.csv")
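        # NOTE: the call above raises, so the asserts below never run; they only document the expected failure shape.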
        assert invalid_result["success"] is False
        assert (
            invalid_result["statistics"]["evaluated_expectations"]
            != invalid_result["statistics"]["successful_expectations"]
        )

    assert task_object.python_interface.inputs == {"dataset": str}
Example #22
def test_notebook_task_simple():
    nb_name = "nb-spark"
    nb = NotebookTask(
        name="test",
        notebook_path=_get_nb_path(nb_name, abs=False),
        outputs=kwtypes(df=FlyteSchema[kwtypes(name=str, age=int)]),
        task_config=Spark(spark_conf={"x": "y"}),
    )
    n, out, render = nb.execute()
    assert nb.python_interface.outputs.keys() == {
        "df", "out_nb", "out_rendered_nb"
    }
    assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out")
    assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html")
Example #23
def test_raw_shell_task_instantiation(capfd):
    if script_sh_2 is None:
        return
    pst = RawShellTask(
        name="test",
        debug=True,
        inputs=flytekit.kwtypes(env=typing.Dict[str, str], script_args=str, script_file=str),
        output_locs=[
            OutputLocation(
                var="out",
                var_type=FlyteDirectory,
                location="{ctx.working_directory}",
            )
        ],
        script="""
#!/bin/bash

set -uex

cd {ctx.working_directory}

{inputs.export_env}

bash {inputs.script_file} {inputs.script_args}
""",
    )
    pst(script_file=script_sh_2, script_args="first_arg second_arg", env={})
    cap = capfd.readouterr()
    assert "first_arg" in cap.out
    assert "second_arg" in cap.out
Example #24
File: shell.py, Project: flyteorg/flytekit
def get_raw_shell_task(name: str) -> RawShellTask:

    return RawShellTask(
        name=name,
        debug=True,
        inputs=flytekit.kwtypes(env=typing.Dict[str, str],
                                script_args=str,
                                script_file=str),
        output_locs=[
            OutputLocation(
                var="out",
                var_type=FlyteDirectory,
                location="{ctx.working_directory}",
            )
        ],
        script="""
#!/bin/bash

set -uex

cd {ctx.working_directory}

{inputs.export_env}

bash {inputs.script_file} {inputs.script_args}
""",
    )
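
The {inputs.export_env} placeholder is where RawShellTask injects export statements built from the env dict input, so the invoked script file runs with those variables set; {ctx.working_directory} likewise resolves at run time.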
Example #25
def test_ge_with_task():
    task_object = GreatExpectationsTask(
        name="test6",
        datasource_name="data",
        inputs=kwtypes(dataset=str),
        expectation_suite_name="test.demo",
        data_connector_name="data_example_data_connector",
    )

    @task
    def my_task(csv_file: str) -> int:
        df = pd.read_csv(os.path.join("data", csv_file))
        return df.shape[0]

    @workflow
    def valid_wf(dataset: str = "yellow_tripdata_sample_2019-01.csv") -> int:
        task_object(dataset=dataset)
        return my_task(csv_file=dataset)

    @workflow
    def invalid_wf(dataset: str = "yellow_tripdata_sample_2019-02.csv") -> int:
        task_object(dataset=dataset)
        return my_task(csv_file=dataset)

    valid_result = valid_wf()
    assert valid_result == 10000

    with pytest.raises(ValidationError, match=r".*passenger_count -> expect_column_min_to_be_between.*"):
        invalid_wf()
Example #26
def test_bad_conversion():
    orig = FlyteSchema[kwtypes(my_custom=bool)]
    lt = TypeEngine.to_literal_type(orig)
    # Make a not real column type
    lt.schema.columns[0]._type = 15
    with pytest.raises(ValueError):
        TypeEngine.guess_python_type(lt)
Example #27
    def __init__(
        self,
        name: str,
        query_template: str,
        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        task_config: typing.Optional[SQLite3Config] = None,
        output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None,
        **kwargs,
    ):
        if task_config is None or task_config.uri is None:
            raise ValueError("SQLite DB uri is required.")
        outputs = kwtypes(results=output_schema_type if output_schema_type else FlyteSchema)
        super().__init__(
            name=name,
            task_config=task_config,
            # If you make changes to this task itself, you'll have to bump this image to what the release _will_ be.
            container_image="ghcr.io/flyteorg/flytekit:py38-v0.19.0b7",
            executor_type=SQLite3TaskExecutor,
            task_type=self._SQLITE_TASK_TYPE,
            query_template=query_template,
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )
Example #28
File: task.py, Project: flyteorg/flytekit
    def __init__(
        self,
        name: str,
        query_template: str,
        task_config: SQLAlchemyConfig,
        inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
        output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = FlyteSchema,
        container_image: str = SQLAlchemyDefaultImages.default_image(),
        **kwargs,
    ):
        if output_schema_type:
            outputs = kwtypes(results=output_schema_type)
        else:
            outputs = None

        super().__init__(
            name=name,
            task_config=task_config,
            executor_type=SQLAlchemyTaskExecutor,
            task_type=self._SQLALCHEMY_TASK_TYPE,
            query_template=query_template,
            container_image=container_image,
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )
Example #29
def test_notebook_task_simple():
    nb_name = "nb-simple"
    nb = NotebookTask(
        name="test",
        notebook_path=_get_nb_path(nb_name, abs=False),
        inputs=kwtypes(pi=float),
        outputs=kwtypes(square=float),
    )
    sqr, out, render = nb.execute(pi=4)
    assert sqr == 16.0
    assert nb.python_interface.inputs == {"pi": float}
    assert nb.python_interface.outputs.keys() == {
        "square", "out_nb", "out_rendered_nb"
    }
    assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out")
    assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html")
Example #30
def test_sql_task():
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2, cache=True, cache_version="0.1"),
    )

    @task(cache=True, cache_version="0.1.2")
    def t1() -> datetime.datetime:
        global n_cached_task_calls
        n_cached_task_calls += 1
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    with task_mock(sql) as mock:
        mock.return_value = pandas.DataFrame(data={
            "x": [1, 2],
            "y": ["3", "4"]
        })
        assert n_cached_task_calls == 0
        assert (my_wf().open().all() == pandas.DataFrame(data={
            "x": [1, 2],
            "y": ["3", "4"]
        })).all().all()
        assert n_cached_task_calls == 1
        # The second and third calls hit the cache
        assert (my_wf().open().all() == pandas.DataFrame(data={
            "x": [1, 2],
            "y": ["3", "4"]
        })).all().all()
        assert n_cached_task_calls == 1
        assert (my_wf().open().all() == pandas.DataFrame(data={
            "x": [1, 2],
            "y": ["3", "4"]
        })).all().all()
        assert n_cached_task_calls == 1