def test_wf1_with_sql_with_patch():
    """Run a workflow whose SQL task is mocked through the ``@patch`` decorator.

    ``t1`` produces a timestamp that feeds ``sql``. The decorated inner function
    receives the mock for ``sql``, makes it return a fixed DataFrame, and checks
    that the workflow's FlyteSchema output contains the same data.
    """
    def frame():
        # A fresh DataFrame on every call, so the mock's return value and the
        # expected value are distinct objects.
        return pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})

    sql = SQLTask(
        "my-query",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        return sql(ds=t1())

    @patch(sql)
    def check_mocked_workflow(mocked_sql):
        mocked_sql.return_value = frame()
        result = my_wf().open().all()
        assert (result == frame()).all().all()

    # pytest does not collect functions nested inside a test, so call it explicitly.
    check_mocked_workflow()
# Example #2
# 0
def test_wf1_with_sql():
    """Run a workflow whose SQL task is mocked through the ``task_mock`` context manager.

    ``t1`` produces a timestamp that feeds ``sql``. Inside the ``with`` block the
    mock for ``sql`` returns a fixed DataFrame, and the workflow's FlyteSchema
    output must contain the same data. After the block, the test asserts that
    ``FlyteContextManager.size()`` is 1 (presumably confirming that no extra
    context was left pushed — verify against flytekit).
    """
    sql = SQLTask(
        "my-query",
        query_template=
        "SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    with task_mock(sql) as mock:
        mock.return_value = pandas.DataFrame(data={
            "x": [1, 2],
            "y": ["3", "4"]
        })
        assert (my_wf().open().all() == pandas.DataFrame(data={
            "x": [1, 2],
            "y": ["3", "4"]
        })).all().all()
    # NOTE(review): `context_manager` is not imported anywhere in the visible
    # file; confirm it is in scope (e.g. `from flytekit.core import context_manager`).
    assert context_manager.FlyteContextManager.size() == 1
# Example #4
# 0
def test_sql_task():
    """Check that a cached task's body runs only once across repeated workflow runs.

    ``t1`` is declared with ``cache=True`` and increments the global
    ``n_cached_task_calls`` each time its body executes. ``sql`` is mocked to
    return a fixed DataFrame. The workflow is executed three times, and the
    counter must read 1 after every run.
    """
    sql = SQLTask(
        "my-query",
        outputs=kwtypes(results=FlyteSchema),
        inputs=kwtypes(ds=datetime.datetime),
        metadata=TaskMetadata(retries=2, cache=True, cache_version="0.1"),
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
    )

    @task(cache=True, cache_version="0.1.2")
    def t1() -> datetime.datetime:
        global n_cached_task_calls
        n_cached_task_calls += 1
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        return sql(ds=t1())

    def expected_frame():
        return pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})

    with task_mock(sql) as mock:
        mock.return_value = expected_frame()
        assert n_cached_task_calls == 0
        # The first run executes t1; the second and third runs hit the cache,
        # so the counter stays at 1.
        for _ in range(3):
            assert (my_wf().open().all() == expected_frame()).all().all()
            assert n_cached_task_calls == 1
"""

import datetime

import pandas
from flytekit import SQLTask, TaskMetadata, kwtypes, task, workflow
from flytekit.testing import patch, task_mock
from flytekit.types.schema import FlyteSchema

# %%
# A generic SQL task. By default it is not wired to any datastore and no plugin
# handles it, so it must be mocked before it can run.
sql = SQLTask(
    "my-query",
    inputs=kwtypes(ds=datetime.datetime),
    outputs=kwtypes(results=FlyteSchema),
    metadata=TaskMetadata(retries=2),
    query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
)


# %%
# A plain Python task; unlike the SQL task above, it can run locally as-is.
@task
def t1() -> datetime.datetime:
    """Return the current local time as a naive ``datetime``."""
    current = datetime.datetime.now()
    return current


# %%
# Declare a workflow that chains these two tasks together.
@workflow