def test_wf_container_task_multiple():
    """Chain two mocked container tasks inside one workflow.

    A "square" task and a "sum" task are declared as ``ContainerTask``
    objects, both are replaced via ``task_mock``, and the test asserts on
    each task called directly and on the workflow that combines them.
    """
    square = ContainerTask(
        name="square",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(val=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.val}} * {{.Inputs.val}} )) | tee /var/outputs/out"],
    )

    # Bound to ``sum_task`` instead of ``sum`` so the builtin ``sum`` is not
    # shadowed inside this test; the ``name="sum"`` argument is unchanged.
    sum_task = ContainerTask(
        name="sum",
        input_data_dir="/var/flyte/inputs",
        output_data_dir="/var/flyte/outputs",
        inputs=kwtypes(x=int, y=int),
        outputs=kwtypes(out=int),
        image="alpine",
        command=["sh", "-c", "echo $(( {{.Inputs.x}} + {{.Inputs.y}} )) | tee /var/flyte/outputs/out"],
    )

    @workflow
    def raw_container_wf(val1: int, val2: int) -> int:
        return sum_task(x=square(val=val1), y=square(val=val2))

    with task_mock(square) as square_mock, task_mock(sum_task) as sum_mock:
        # Stand-ins that compute what the shell commands above describe.
        square_mock.side_effect = lambda val: val * val
        assert square(val=10) == 100

        sum_mock.side_effect = lambda x, y: x + y
        assert sum_task(x=10, y=10) == 20

        # 10 * 10 + 10 * 10
        assert raw_container_wf(val1=10, val2=10) == 200
def test_wf1_with_sql():
    """Run a workflow that feeds a task's datetime into a mocked SQL task.

    The SQL task is replaced with ``task_mock`` so that it returns a fixed
    DataFrame; the workflow output is opened, read back, and compared
    element-wise against an equal DataFrame.
    """
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    # One column spec, but a fresh DataFrame is built for each use.
    columns = {"x": [1, 2], "y": ["3", "4"]}

    with task_mock(sql) as mock:
        mock.return_value = pandas.DataFrame(data=columns)
        actual = my_wf().open().all()
        expected = pandas.DataFrame(data=columns)
        assert (actual == expected).all().all()
def test_wf_container_task():
    """Pass a Python task's outputs into a mocked raw container task.

    ``t1`` produces an ``(int, str)`` pair that the workflow forwards to the
    container task ``t2``. ``t2`` is replaced via ``task_mock`` with a stub
    returning ``None``; the test calls it directly and then runs the workflow.
    """

    @task
    def t1(a: int) -> (int, str):
        return a + 2, str(a) + "-HELLO"

    t2 = ContainerTask(
        "raw",
        image="alpine",
        inputs=kwtypes(a=int, b=str),
        input_data_dir="/tmp",
        output_data_dir="/tmp",
        command=["cat"],
        arguments=["/tmp/a"],
    )

    @workflow
    def wf(a: int):
        x, y = t1(a=a)
        t2(a=x, b=y)

    def _no_output(a, b):
        # Accepts the container task's two inputs and produces nothing.
        return None

    with task_mock(t2) as mock:
        mock.side_effect = _no_output
        direct_result = t2(a=10, b="hello")
        assert direct_result is None
        wf(a=10)
def test_sql_task():
    """Count invocations of a cached task across repeated workflow runs.

    ``t1`` is declared with ``cache=True`` and increments the module-level
    ``n_cached_task_calls`` counter each time its body runs. With the SQL
    task mocked, the workflow is run three times; the counter is asserted
    to be 0 beforehand and 1 after every run.
    """
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2, cache=True, cache_version="0.1"),
    )

    @task(cache=True, cache_version="0.1.2")
    def t1() -> datetime.datetime:
        global n_cached_task_calls
        n_cached_task_calls += 1
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    # One column spec, but a fresh DataFrame is built for each use.
    columns = {"x": [1, 2], "y": ["3", "4"]}

    with task_mock(sql) as mock:
        mock.return_value = pandas.DataFrame(data=columns)
        assert n_cached_task_calls == 0

        # The first run executes t1; the second and third calls hit the cache,
        # so the counter is expected to stay at 1 after each of the three runs.
        for _ in range(3):
            actual = my_wf().open().all()
            assert (actual == pandas.DataFrame(data=columns)).all().all()
            assert n_cached_task_calls == 1
def test_ref_task_more():
    """Exercise a workflow built around a reference task.

    First the workflow is called without any mock and is expected to raise
    with a message containing "You must mock this out". Then the reference
    task is replaced via ``task_mock`` and the workflow is expected to
    return the mocked value.
    """

    @reference_task(
        project="flytesnacks",
        domain="development",
        name="recipes.aaa.simple.join_strings",
        version="553018f39e519bdb2597b652639c30ce16b99c79",
    )
    def ref_t1(a: typing.List[str]) -> str:
        ...

    @workflow
    def wf1(in1: typing.List[str]) -> str:
        return ref_t1(a=in1)

    words = ["hello", "world"]

    # Without a mock, calling the workflow is expected to raise.
    with pytest.raises(Exception) as e:
        wf1(in1=words)
    # Checked after the block: nothing following the raising call inside the
    # ``with`` body would ever execute.
    assert "You must mock this out" in f"{e}"

    with task_mock(ref_t1) as mock:
        mock.return_value = "hello"
        mocked_result = wf1(in1=words)
        assert mocked_result == "hello"