def test_wf1_with_sql_with_patch():
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    @patch(sql)
    def test_user_demo_test(mock_sql):
        mock_sql.return_value = pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})
        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()

    # Invoke the patched function explicitly: pytest does not collect test functions
    # nested inside other tests.
    test_user_demo_test()
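
# A hedged sketch (fetch_data and wf_patch_demo are hypothetical names, not part of the file
# above): flytekit.testing.patch is not specific to SQLTask and also works on ordinary Python
# tasks, which is useful when a task is expensive or has side effects.
@task
def fetch_data() -> int:
    return 42

@workflow
def wf_patch_demo() -> int:
    return fetch_data()

@patch(fetch_data)
def test_fetch_data_patched(mock_fetch):
    # The mock stands in for fetch_data during the local workflow execution.
    mock_fetch.return_value = 7
    assert wf_patch_demo() == 7

test_fetch_data_patched()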
def test_wf1_with_sql():
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    with task_mock(sql) as mock:
        mock.return_value = pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})
        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()

    # The local execution above should not leak any extra Flyte contexts.
    assert context_manager.FlyteContextManager.size() == 1
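
# A hedged sketch (double and double_wf are hypothetical names): task_mock restores the task's
# real behavior when the with-block exits, so a single test can exercise both the mocked and
# the real implementation.
@task
def double(x: int) -> int:
    return x * 2

@workflow
def double_wf(x: int) -> int:
    return double(x=x)

def test_task_mock_is_scoped():
    with task_mock(double) as mock:
        mock.return_value = 0
        assert double_wf(x=3) == 0  # mocked inside the block
    assert double_wf(x=3) == 6  # real implementation restored outside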
def dummy_node(node_id) -> Node:
    # Build a minimal Node wrapping a placeholder SQLTask, for use in tests.
    n = Node(
        node_id,
        metadata=None,
        bindings=[],
        upstream_nodes=[],
        flyte_entity=SQLTask(name="x", query_template="x", inputs={}),
    )
    n._id = node_id
    return n
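
# A hedged usage sketch: the helper above makes it cheap to fabricate nodes for tests that
# only care about node identity (e.g. ordering or de-duplication logic).
n1 = dummy_node("n1")
n2 = dummy_node("n2")
assert n1.id == "n1" and n2.id == "n2"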
def test_sql_task():
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2, cache=True, cache_version="0.1"),
    )

    @task(cache=True, cache_version="0.1.2")
    def t1() -> datetime.datetime:
        # n_cached_task_calls is a module-level counter defined elsewhere in this file.
        global n_cached_task_calls
        n_cached_task_calls += 1
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    with task_mock(sql) as mock:
        mock.return_value = pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})
        assert n_cached_task_calls == 0
        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()
        assert n_cached_task_calls == 1
        # The second and third calls hit the cache, so t1 does not run again.
        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()
        assert n_cached_task_calls == 1
        assert (my_wf().open().all() == pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})).all().all()
        assert n_cached_task_calls == 1
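
# A hedged sketch of the local caching behavior exercised above (add_one and add_one_wf are
# hypothetical names): cached results are keyed on the task's inputs and cache_version, so a
# repeat call with identical inputs should not execute the task body again. Note that the
# local cache persists on disk between sessions; it can be cleared with
# `pyflyte local-cache clear`.
n_add_one_calls = 0

@task(cache=True, cache_version="1.0")
def add_one(x: int) -> int:
    global n_add_one_calls
    n_add_one_calls += 1
    return x + 1

@workflow
def add_one_wf(x: int) -> int:
    return add_one(x=x)

def test_cache_hits_locally():
    add_one_wf(x=1)
    calls_after_first = n_add_one_calls
    add_one_wf(x=1)
    # The repeat call is served from the cache, so the call count does not grow.
    assert n_add_one_calls == calls_after_first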
""" import datetime import pandas from flytekit import SQLTask, TaskMetadata, kwtypes, task, workflow from flytekit.testing import patch, task_mock from flytekit.types.schema import FlyteSchema # %% # This is a generic SQL task (and is by default not hooked up to any datastore nor handled by any plugin), and must # be mocked. sql = SQLTask( "my-query", query_template= "SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10", inputs=kwtypes(ds=datetime.datetime), outputs=kwtypes(results=FlyteSchema), metadata=TaskMetadata(retries=2), ) # %% # This is a task that can run locally @task def t1() -> datetime.datetime: return datetime.datetime.now() # %% # Declare a workflow that chains these two tasks together. @workflow