Ejemplo n.º 1
0
def test_map_task_types():
    """Check the mapped ``t1`` output for a valid list and the errors for bad input.

    Mapping over ``[5, 6]`` must yield ``["7", "8"]``; passing a bare int or a
    list of strings must raise ``TypeError``.
    """
    mapped_result = map_task(t1, metadata=TaskMetadata(retries=1))(a=[5, 6])
    assert mapped_result == ["7", "8"]

    # Each rejected input is checked against a freshly built map task.
    for rejected_input in (1, ["invalid", "args"]):
        with pytest.raises(TypeError):
            _ = map_task(t1, metadata=TaskMetadata(retries=1))(a=rejected_input)
Ejemplo n.º 2
0
def test_serialization():
    """Serialize a map task built from ``t1`` and check the resulting template.

    Verifies the task type, the task type version and the exact container
    argument list of the serialized spec.
    """
    mapped = map_task(t1, metadata=TaskMetadata(retries=1))

    image = Image(name="default", fqn="test", tag="tag")
    images = ImageConfig(default_image=image, images=[image])
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=images,
    )

    template = get_serializable(OrderedDict(), settings, mapped).template

    expected_args = [
        "pyflyte-map-execute",
        "--inputs", "{{.input}}",
        "--output-prefix", "{{.outputPrefix}}",
        "--raw-output-data-prefix", "{{.rawOutputDataPrefix}}",
        "--resolver", "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module", "test_map_task",
        "task-name", "t1",
    ]
    assert template.type == "container_array"
    assert template.task_type_version == 1
    assert template.container.args == expected_args
Ejemplo n.º 3
0
def test_wf1_with_sql_with_patch():
    """Run a workflow whose SQL task is replaced by a mock through ``@patch``.

    The mock returns a small DataFrame and the workflow output, once opened
    and read, must equal that same data.
    """
    query = "SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10"
    sql = SQLTask(
        "my-query",
        query_template=query,
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        # Feed the timestamp produced by t1 straight into the SQL task.
        return sql(ds=t1())

    def build_frame():
        # A fresh frame per call, so the mock value and the expectation are
        # separate objects.
        return pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})

    @patch(sql)
    def test_user_demo_test(mock_sql):
        mock_sql.return_value = build_frame()
        cellwise_equal = my_wf().open().all() == build_frame()
        assert cellwise_equal.all().all()

    # Nested tests are not collected, so invoke it explicitly.
    test_user_demo_test()
Ejemplo n.º 4
0
def test_wf1_with_sql():
    """Run a workflow whose SQL task is mocked with the ``task_mock`` context manager.

    Inside the mock context the workflow output must equal the mocked
    DataFrame; afterwards the Flyte context manager size must be 1.
    """
    query = "SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10"
    sql = SQLTask(
        "my-query",
        query_template=query,
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2),
    )

    @task
    def t1() -> datetime.datetime:
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        # Feed the timestamp produced by t1 straight into the SQL task.
        return sql(ds=t1())

    def build_frame():
        # A fresh frame per call, so the mock value and the expectation are
        # separate objects.
        return pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})

    with task_mock(sql) as mock:
        mock.return_value = build_frame()
        cellwise_equal = my_wf().open().all() == build_frame()
        assert cellwise_equal.all().all()

    assert context_manager.FlyteContextManager.size() == 1
Ejemplo n.º 5
0
def test_serialization(serialization_settings):
    """Serialize a map task built from ``t1`` using the supplied settings.

    Verifies the ``minSuccessRatio`` custom field, the task type, the task
    type version and the exact container argument list.
    """
    mapped = map_task(t1, metadata=TaskMetadata(retries=1))
    template = get_serializable(OrderedDict(), serialization_settings, mapped).template

    expected_args = [
        "pyflyte-map-execute",
        "--inputs", "{{.input}}",
        "--output-prefix", "{{.outputPrefix}}",
        "--raw-output-data-prefix", "{{.rawOutputDataPrefix}}",
        "--checkpoint-path", "{{.checkpointOutputPrefix}}",
        "--prev-checkpoint", "{{.prevCheckpointPrefix}}",
        "--resolver", "flytekit.core.python_auto_container.default_task_resolver",
        "--",
        "task-module", "tests.flytekit.unit.core.test_map_task",
        "task-name", "t1",
    ]

    # By default all map_task tasks will have their custom fields set.
    assert template.custom["minSuccessRatio"] == 1.0
    assert template.type == "container_array"
    assert template.task_type_version == 1
    assert template.container.args == expected_args
Ejemplo n.º 6
0
 def my_wf(x: typing.List[int]) -> typing.List[str]:
     # Wrap my_mappable_task as a map task with retry, concurrency and
     # minimum-success-ratio settings.
     mapped = map_task(
         my_mappable_task,
         metadata=TaskMetadata(retries=1),
         concurrency=10,
         min_success_ratio=0.75,
     )
     # Apply it to the input list, then attach the cpu override to the result.
     mapped_call = mapped(a=x)
     return mapped_call.with_overrides(cpu="10M")
Ejemplo n.º 7
0
def test_serialization_workflow_def():
    """Serialize two workflows that map over the same task with different metadata.

    Each workflow must serialize to a single node, and the shared entity
    registry must end up holding two map tasks whose name contains "complex".
    """
    @task
    def complex_task(a: int) -> str:
        return str(a + 2)

    shared_map_task = map_task(complex_task, metadata=TaskMetadata(retries=1))

    @workflow
    def w1(a: typing.List[int]) -> typing.List[str]:
        return shared_map_task(a=a)

    @workflow
    def w2(a: typing.List[int]) -> typing.List[str]:
        return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)

    image = Image(name="default", fqn="test", tag="tag")
    settings = context_manager.SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=image, images=[image]),
    )

    # Both workflows are serialized into the same registry, w1 first.
    registry = OrderedDict()
    for wf in (w1, w2):
        spec = get_serializable(registry, settings, wf)
        assert spec.template is not None
        assert len(spec.template.nodes) == 1

    tasks_seen = [
        entity
        for entity in registry
        if isinstance(entity, MapPythonTask) and "complex" in entity.name
    ]

    assert len(tasks_seen) == 2
    print(tasks_seen[0])
Ejemplo n.º 8
0
def test_sql_task():
    """Run a workflow with a cached task three times and count the task's executions.

    The SQL task is mocked; the module-level counter ``n_cached_task_calls``
    must be 0 before the first run and stay at 1 after each of the three runs.
    """
    query = "SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ .Inputs.ds }}' LIMIT 10"
    sql = SQLTask(
        "my-query",
        query_template=query,
        inputs=kwtypes(ds=datetime.datetime),
        outputs=kwtypes(results=FlyteSchema),
        metadata=TaskMetadata(retries=2, cache=True, cache_version="0.1"),
    )

    @task(cache=True, cache_version="0.1.2")
    def t1() -> datetime.datetime:
        global n_cached_task_calls
        n_cached_task_calls += 1
        return datetime.datetime.now()

    @workflow
    def my_wf() -> FlyteSchema:
        # Feed the timestamp produced by t1 straight into the SQL task.
        return sql(ds=t1())

    def build_frame():
        # A fresh frame per call, so the mock value and the expectation are
        # separate objects.
        return pandas.DataFrame(data={"x": [1, 2], "y": ["3", "4"]})

    with task_mock(sql) as mock:
        mock.return_value = build_frame()
        assert n_cached_task_calls == 0

        # The first run executes t1; the second and third calls hit the cache,
        # so the counter reads 1 after every run.
        for _ in range(3):
            cellwise_equal = my_wf().open().all() == build_frame()
            assert cellwise_equal.all().all()
            assert n_cached_task_calls == 1
Ejemplo n.º 9
0
def test_serialization_workflow_def(serialization_settings):
    """Serialize two workflows that map over the same task with different metadata.

    Each workflow must serialize to a single node, and the shared entity
    registry must end up holding two map tasks whose name contains "complex".
    """
    @task
    def complex_task(a: int) -> str:
        return str(a + 2)

    shared_map_task = map_task(complex_task, metadata=TaskMetadata(retries=1))

    @workflow
    def w1(a: typing.List[int]) -> typing.List[str]:
        return shared_map_task(a=a)

    @workflow
    def w2(a: typing.List[int]) -> typing.List[str]:
        return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)

    # Both workflows are serialized into the same registry, w1 first.
    registry = OrderedDict()
    for wf in (w1, w2):
        spec = get_serializable(registry, serialization_settings, wf)
        assert spec.template is not None
        assert len(spec.template.nodes) == 1

    tasks_seen = [
        entity
        for entity in registry
        if isinstance(entity, MapPythonTask) and "complex" in entity.name
    ]

    assert len(tasks_seen) == 2
    print(tasks_seen[0])
Ejemplo n.º 10
0
 def w2(a: typing.List[int]) -> typing.List[str]:
     # Build a map task over complex_task with two retries, then apply it to a.
     retrying_map = map_task(complex_task, metadata=TaskMetadata(retries=2))
     return retrying_map(a=a)
Ejemplo n.º 11
0
 def my_wf(a: typing.List[int]) -> int:
     # Map t1 over the input list and hand the mapped output to t2.
     mapped = map_task(t1, metadata=TaskMetadata(retries=1))
     return t2(x=mapped(a=a))
Ejemplo n.º 12
0
 def my_wf(a: typing.List[int]) -> (int, str):
     # Map t1 over the input list; the mapped call is unpacked into two
     # outputs, which are passed to t2 as a and b.
     mapped = maptask(t1, metadata=TaskMetadata(retries=1))
     first_out, second_out = mapped(a=a)
     return t2(a=first_out, b=second_out)
Ejemplo n.º 13
0
def test_map_task_metadata():
    """Check which metadata object a map task over ``t2`` ends up holding.

    With an explicit ``metadata`` argument the map task must expose that very
    object; without one it must expose ``t2.metadata``.
    """
    explicit_meta = TaskMetadata(retries=1)
    with_explicit = map_task(t2, metadata=explicit_meta)
    assert with_explicit.metadata is explicit_meta

    with_default = map_task(t2)
    assert with_default.metadata is t2.metadata