def test_map_task_types():
    """A mapped task must type-check its inputs at call time."""
    # Correct input: a list of ints is mapped element-wise through t1.
    results = map_task(t1, metadata=TaskMetadata(retries=1))(a=[5, 6])
    assert results == ["7", "8"]

    # A bare scalar where a list is required is a type error.
    with pytest.raises(TypeError):
        _ = map_task(t1, metadata=TaskMetadata(retries=1))(a=1)

    # A list with the wrong element type is also a type error.
    with pytest.raises(TypeError):
        _ = map_task(t1, metadata=TaskMetadata(retries=1))(a=["invalid", "args"])
def test_serialization():
    """Serializing a mapped task yields a container_array spec with the map-execute entrypoint."""
    mapped = map_task(t1, metadata=TaskMetadata(retries=1))
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    task_spec = get_serializable(OrderedDict(), settings, mapped)

    # Map tasks serialize as the array-task plugin type, version 1.
    assert task_spec.template.type == "container_array"
    assert task_spec.template.task_type_version == 1
    # The container must invoke pyflyte-map-execute with the standard arg layout.
    assert task_spec.template.container.args == [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "test_map_task",
        "task-name",
        "t1",
    ]
def test_wf1_with_sql_with_patch():
    """Patch a SQLTask inside a workflow and check the mocked result flows through."""
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    @patch(sql)
    def test_user_demo_test(mock_sql):
        expected = pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})
        mock_sql.return_value = expected
        # The workflow output schema should round-trip the mocked frame.
        assert (my_wf().open().all() == expected).all().all()

    # Nested test functions are not collected by pytest, so call explicitly.
    test_user_demo_test()
def test_wf1_with_sql():
    """Mock a SQLTask via task_mock and confirm the workflow returns the mocked frame."""
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    with task_mock(sql) as mock:
        expected = pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})
        mock.return_value = expected
        assert (my_wf().open().all() == expected).all().all()

    # The mock context must not leak extra Flyte contexts.
    assert context_manager.FlyteContextManager.size() == 1
def test_serialization(serialization_settings):
    """Serialize a mapped task and verify its custom fields and container args."""
    mapped = map_task(t1, metadata=TaskMetadata(retries=1))
    task_spec = get_serializable(OrderedDict(), serialization_settings, mapped)

    # By default all map_task tasks will have their custom fields set.
    assert task_spec.template.custom["minSuccessRatio"] == 1.0
    assert task_spec.template.type == "container_array"
    assert task_spec.template.task_type_version == 1
    # Expected pyflyte-map-execute invocation, including checkpoint prefixes.
    assert task_spec.template.container.args == [
        "pyflyte-map-execute",
        "--inputs",
        "{{.input}}",
        "--output-prefix",
        "{{.outputPrefix}}",
        "--raw-output-data-prefix",
        "{{.rawOutputDataPrefix}}",
        "--checkpoint-path",
        "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint",
        "{{.prevCheckpointPrefix}}",
        "--resolver",
        "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module",
        "tests.flytekit.unit.core.test_map_task",
        "task-name",
        "t1",
    ]
def my_wf(x: typing.List[int]) -> typing.List[str]:
    """Map my_mappable_task over x with retries, bounded concurrency, and a 75% success floor."""
    mapped = map_task(
        my_mappable_task,
        metadata=TaskMetadata(retries=1),
        concurrency=10,
        min_success_ratio=0.75,
    )
    # NOTE(review): cpu="10M" override is kept verbatim — confirm the unit is intended.
    return mapped(a=x).with_overrides(cpu="10M")
def test_serialization_workflow_def():
    """Two workflows mapping the same task must serialize to two distinct map-task entities."""

    @task
    def complex_task(a: int) -> str:
        b = a + 2
        return str(b)

    mapped = map_task(complex_task, metadata=TaskMetadata(retries=1))

    @workflow
    def w1(a: typing.List[int]) -> typing.List[str]:
        return mapped(a=a)

    @workflow
    def w2(a: typing.List[int]) -> typing.List[str]:
        # Different retry metadata -> a distinct map-task entity from w1's.
        return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    entities = OrderedDict()
    wf1_spec = get_serializable(entities, serialization_settings, w1)
    assert wf1_spec.template is not None
    assert len(wf1_spec.template.nodes) == 1

    wf2_spec = get_serializable(entities, serialization_settings, w2)
    assert wf2_spec.template is not None
    assert len(wf2_spec.template.nodes) == 1

    # Both distinct map-task variants of complex_task must appear once each.
    tasks_seen = [
        entity
        for entity in entities
        if isinstance(entity, MapPythonTask) and "complex" in entity.name
    ]
    assert len(tasks_seen) == 2
    print(tasks_seen[0])
def test_sql_task():
    """A cached upstream task should execute once; repeated workflow runs hit the cache."""
    sql = SQLTask(
        "my-query",
        query_template="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10",
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2, cache=True, cache_version="0.1"),
    )

    @task(cache=True, cache_version="0.1.2")
    def t1() -> datetime.datetime:
        # Count real (non-cached) executions via a module-level counter.
        global n_cached_task_calls
        n_cached_task_calls += 1
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        dt = t1()
        return sql(ds=dt)

    with task_mock(sql) as mock:
        expected = pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})
        mock.return_value = expected

        assert n_cached_task_calls == 0
        # First run actually executes t1.
        assert (my_wf().open().all() == expected).all().all()
        assert n_cached_task_calls == 1

        # The second and third calls hit the cache
        assert (my_wf().open().all() == expected).all().all()
        assert n_cached_task_calls == 1
        assert (my_wf().open().all() == expected).all().all()
        assert n_cached_task_calls == 1
def test_serialization_workflow_def(serialization_settings):
    """Two workflows mapping the same task must serialize to two distinct map-task entities."""

    @task
    def complex_task(a: int) -> str:
        b = a + 2
        return str(b)

    mapped = map_task(complex_task, metadata=TaskMetadata(retries=1))

    @workflow
    def w1(a: typing.List[int]) -> typing.List[str]:
        return mapped(a=a)

    @workflow
    def w2(a: typing.List[int]) -> typing.List[str]:
        # Different retry metadata -> a distinct map-task entity from w1's.
        return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)

    entities = OrderedDict()
    wf1_spec = get_serializable(entities, serialization_settings, w1)
    assert wf1_spec.template is not None
    assert len(wf1_spec.template.nodes) == 1

    wf2_spec = get_serializable(entities, serialization_settings, w2)
    assert wf2_spec.template is not None
    assert len(wf2_spec.template.nodes) == 1

    # Both distinct map-task variants of complex_task must appear once each.
    tasks_seen = [
        entity
        for entity in entities
        if isinstance(entity, MapPythonTask) and "complex" in entity.name
    ]
    assert len(tasks_seen) == 2
    print(tasks_seen[0])
def w2(a: typing.List[int]) -> typing.List[str]:
    """Map complex_task over a, retrying each element up to 2 times."""
    mapped = map_task(complex_task, metadata=TaskMetadata(retries=2))
    return mapped(a=a)
def my_wf(a: typing.List[int]) -> int:
    """Map t1 over a, then reduce the mapped results with t2."""
    mapped_out = map_task(t1, metadata=TaskMetadata(retries=1))(a=a)
    return t2(x=mapped_out)
def my_wf(a: typing.List[int]) -> (int, str):
    """Map t1 over a (two outputs per element), then combine with t2."""
    # NOTE(review): `maptask` (not `map_task`) — presumably a module-level alias
    # or factory defined elsewhere in this file; confirm it exists.
    x, y = maptask(t1, metadata=TaskMetadata(retries=1))(a=a)
    return t2(a=x, b=y)
def test_map_task_metadata():
    """map_task must use the supplied metadata, or fall back to the wrapped task's."""
    explicit_meta = TaskMetadata(retries=1)

    # Explicit metadata is attached as-is (identity, not just equality).
    with_meta = map_task(t2, metadata=explicit_meta)
    assert with_meta.metadata is explicit_meta

    # Without metadata, the mapped task inherits the wrapped task's metadata object.
    without_meta = map_task(t2)
    assert without_meta.metadata is t2.metadata