Example #1
def test_two(two_sample_inputs):
    my_input = two_sample_inputs[0]
    my_input_2 = two_sample_inputs[1]

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteFile]:
        x = []
        for aa in a:
            x.append(aa.main_product)
        return x

    with FlyteContextManager.with_context(
        FlyteContextManager.current_context().with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
    ) as ctx:
        with FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(
                    mode=ExecutionState.Mode.TASK_EXECUTION,
                )
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(
                ctx, d={"a": [my_input, my_input_2]}, type_hints={"a": List[MyInput]}
            )
            dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec.literals["o0"].collection.literals) == 2
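
The `two_sample_inputs` fixture is not included in the snippet. A minimal sketch of such a pytest fixture, assuming it reuses the MyInput construction pattern from Example #6 (all field values here are illustrative, not taken from the original test suite):

import pytest


@pytest.fixture
def two_sample_inputs(folders_and_files_setup):
    # Hypothetical fixture: two MyInput objects built around the temp file from
    # folders_and_files_setup, differing only in their directory URIs.
    def make(static_uri: str, external_uri: str) -> MyInput:
        return MyInput(
            main_product=FlyteFile(folders_and_files_setup[0]),
            apriori_config=MyAprioriConfiguration(
                static_data_dir=FlyteDirectory(static_uri),
                external_data_dir=FlyteDirectory(external_uri),
            ),
            proxy_config=MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file"),
            proxy_params=MyProxyParameters(id="pp_id", job_i_step=1),
        )

    yield [
        make("gs://my-bucket/one", "gs://my-bucket/two"),
        make("gs://my-bucket/three", "gs://my-bucket/four"),
    ]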
Example #2
def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
    # Just leave these here, mock them out so nothing happens
    mock_get_data.return_value = True
    mock_upload_dir.return_value = True

    ctx = context_manager.FlyteContext.current_context()
    with context_manager.FlyteContextManager.with_context(
        ctx.with_execution_state(
            ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
        )
    ) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
        mock_load_proto.return_value = input_literal_map.to_flyte_idl()

        python_task = mock.MagicMock()
        python_task.dispatch_execute.side_effect = Exception("some system exception")

        files = OrderedDict()
        mock_write_to_file.side_effect = get_output_collector(files)
        # See comment in test_dispatch_execute_ignore for why we need to decorate
        system_entry_point(_dispatch_execute)(ctx, python_task, "inputs path", "outputs prefix")
        assert len(files) == 1

        # Exception should've caused an error file
        k = list(files.keys())[0]
        assert "error.pb" in k

        v = list(files.values())[0]
        ed = error_models.ErrorDocument.from_flyte_idl(v)
        # System errors default to recoverable
        assert ed.error.kind == error_models.ContainerError.Kind.RECOVERABLE
        assert "some system exception" in ed.error.message
        assert ed.error.origin == execution_models.ExecutionError.ErrorKind.SYSTEM
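
The mock_* parameters imply a stack of @mock.patch decorators that the snippet does not show. A plausible arrangement, assuming patch targets like those used in flytekit's entrypoint tests (the exact module paths differ between flytekit versions):

from unittest import mock


@mock.patch("flytekit.core.utils.load_proto_from_file")                     # -> mock_load_proto
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data")   # -> mock_get_data
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data")   # -> mock_upload_dir
@mock.patch("flytekit.core.utils.write_proto_to_file")                      # -> mock_write_to_file
def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
    ...

Decorators are applied bottom-up, so the patch closest to the function binds to the first positional argument.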
Example #3
def test_dispatch_execute_user_error_non_recov(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
    # Just leave these here, mock them out so nothing happens
    mock_get_data.return_value = True
    mock_upload_dir.return_value = True

    @task
    def t1(a: int) -> str:
        # Should be interpreted as a non-recoverable user error
        raise ValueError(f"some exception {a}")

    ctx = context_manager.FlyteContext.current_context()
    with context_manager.FlyteContextManager.with_context(
        ctx.with_execution_state(
            ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
        )
    ) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
        mock_load_proto.return_value = input_literal_map.to_flyte_idl()

        files = OrderedDict()
        mock_write_to_file.side_effect = get_output_collector(files)
        # See comment in test_dispatch_execute_ignore for why we need to decorate
        system_entry_point(_dispatch_execute)(ctx, t1, "inputs path", "outputs prefix")
        assert len(files) == 1

        # Exception should've caused an error file
        k = list(files.keys())[0]
        assert "error.pb" in k

        v = list(files.values())[0]
        ed = error_models.ErrorDocument.from_flyte_idl(v)
        assert ed.error.kind == error_models.ContainerError.Kind.NON_RECOVERABLE
        assert "some exception 5" in ed.error.message
        assert ed.error.origin == execution_models.ExecutionError.ErrorKind.USER
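
get_output_collector is a test helper not shown in these snippets. Given how it is used (as the side_effect of the patched proto-writing call), a minimal sketch consistent with the call sites would be:

def get_output_collector(results: OrderedDict):
    # Capture every proto the entrypoint tries to write, keyed by its target path,
    # so the test can later look for error.pb or outputs.pb among the written files.
    def output_collector(proto, path):
        results[path] = proto

    return output_collector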
Example #4
def test_dispatch_execute_normal(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
    # Just leave these here, mock them out so nothing happens
    mock_get_data.return_value = True
    mock_upload_dir.return_value = True

    @task
    def t1(a: int) -> str:
        return f"string is: {a}"

    ctx = context_manager.FlyteContext.current_context()
    with context_manager.FlyteContextManager.with_context(
        ctx.with_execution_state(
            ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
        )
    ) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
        mock_load_proto.return_value = input_literal_map.to_flyte_idl()

        files = OrderedDict()
        mock_write_to_file.side_effect = get_output_collector(files)
        # See comment in test_dispatch_execute_ignore for why we need to decorate
        system_entry_point(_dispatch_execute)(ctx, t1, "inputs path", "outputs prefix")
        assert len(files) == 1

        # A successful run should've written an outputs file.
        k = list(files.keys())[0]
        assert "outputs.pb" in k

        v = list(files.values())[0]
        lm = _literal_models.LiteralMap.from_flyte_idl(v)
        assert lm.literals["o0"].scalar.primitive.string_value == "string is: 5"
Example #5
def test_pb_guess_python_type():
    artifact_tag = catalog_pb2.CatalogArtifactTag(artifact_id="artifact_1", name="artifact_name")

    x = {"a": artifact_tag}
    lt = TypeEngine.to_literal_type(catalog_pb2.CatalogArtifactTag)
    gt = TypeEngine.guess_python_type(lt)
    assert gt == catalog_pb2.CatalogArtifactTag
    ctx = FlyteContextManager.current_context()
    lm = TypeEngine.dict_to_literal_map(ctx, x, {"a": gt})
    pv = TypeEngine.to_python_value(ctx, lm.literals["a"], gt)
    assert pv == artifact_tag
Example #6
def test_dc_dyn_directory(folders_and_files_setup):
    proxy_c = MyProxyConfiguration(splat_data_dir="/tmp/proxy_splat", apriori_file="/opt/config/a_file")
    proxy_p = MyProxyParameters(id="pp_id", job_i_step=1)

    my_input_gcs = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/one"),
            external_data_dir=FlyteDirectory("gs://my-bucket/two"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )

    my_input_gcs_2 = MyInput(
        main_product=FlyteFile(folders_and_files_setup[0]),
        apriori_config=MyAprioriConfiguration(
            static_data_dir=FlyteDirectory("gs://my-bucket/three"),
            external_data_dir=FlyteDirectory("gs://my-bucket/four"),
        ),
        proxy_config=proxy_c,
        proxy_params=proxy_p,
    )

    @dynamic
    def dt1(a: List[MyInput]) -> List[FlyteDirectory]:
        x = []
        for aa in a:
            x.append(aa.apriori_config.external_data_dir)

        return x

    ctx = FlyteContextManager.current_context()
    cb = (
        ctx.new_builder()
        .with_serialization_settings(
            SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
            )
        )
        .with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
    )
    with FlyteContextManager.with_context(cb) as ctx:
        input_literal_map = TypeEngine.dict_to_literal_map(
            ctx, d={"a": [my_input_gcs, my_input_gcs_2]}, type_hints={"a": List[MyInput]}
        )
        dynamic_job_spec = dt1.dispatch_execute(ctx, input_literal_map)
        assert dynamic_job_spec.literals["o0"].collection.literals[0].scalar.blob.uri == "gs://my-bucket/two"
        assert dynamic_job_spec.literals["o0"].collection.literals[1].scalar.blob.uri == "gs://my-bucket/four"
Example #7
def test_wf1_with_fast_dynamic():
    @task
    def t1(a: int) -> str:
        a = a + 2
        return "fast-" + str(a)

    @dynamic
    def my_subwf(a: int) -> typing.List[str]:
        s = []
        for i in range(a):
            s.append(t1(a=i))
        return s

    @workflow
    def my_wf(a: int) -> typing.List[str]:
        v = my_subwf(a=a)
        return v

    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(
            flytekit.configuration.SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
                fast_serialization_settings=FastSerializationSettings(
                    enabled=True,
                    destination_dir="/User/flyte/workflows",
                    distribution_location="s3://my-s3-bucket/fast/123",
                ),
            )
        )
    ) as ctx:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 5})
            dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec._nodes) == 5
            assert len(dynamic_job_spec.tasks) == 1
            args = " ".join(dynamic_job_spec.tasks[0].container.args)
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows")

    assert context_manager.FlyteContextManager.size() == 1
Example #8
def test_dynamic():
    @dynamic
    def my_subwf(a: int) -> typing.List[int]:
        s = []
        for i in range(a):
            s.append(ft(a=i))
        return s

    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(
            context_manager.SerializationSettings(
                project="test_proj",
                domain="test_domain",
                version="abc",
                image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
                env={},
                fast_serialization_settings=FastSerializationSettings(enabled=True),
            )
        )
    ) as ctx:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": 2})
            # Test that it works
            dynamic_job_spec = my_subwf.dispatch_execute(ctx, input_literal_map)
            assert len(dynamic_job_spec._nodes) == 2
            assert len(dynamic_job_spec.tasks) == 1
            assert dynamic_job_spec.tasks[0].id == ft.id

            # Test that the fast execute stuff does not get applied because the commands of tasks fetched from
            # Admin should never change.
            args = " ".join(dynamic_job_spec.tasks[0].container.args)
            assert not args.startswith("pyflyte-fast-execute")
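
ft is not defined in this snippet; the surrounding test presumably builds a control-plane task object beforehand, so the dynamic workflow calls something that behaves like a task fetched from Admin. A rough sketch of one way to construct such an object (the helper names and the promotion step are assumptions and depend on the flytekit version):

from collections import OrderedDict

from flytekit import task
from flytekit.remote import FlyteTask
from flytekit.tools.translator import get_serializable


@task
def _t(a: int) -> int:
    return a + 1


# Hypothetical: serialize a regular @task and promote its template to a FlyteTask,
# emulating a task whose definition was fetched from Flyte Admin.
ss = context_manager.SerializationSettings(
    project="test_proj",
    domain="test_domain",
    version="abc",
    image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
)
ft = FlyteTask.promote_from_model(get_serializable(OrderedDict(), ss, _t).template)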
Example #9
def test_fast():
    REQUESTS_GPU = Resources(cpu="123m",
                             mem="234Mi",
                             ephemeral_storage="123M",
                             gpu="1")
    LIMITS_GPU = Resources(cpu="124M",
                           mem="235Mi",
                           ephemeral_storage="124M",
                           gpu="1")

    def get_minimal_pod_task_config() -> Pod:
        primary_container = V1Container(name="flytetask")
        pod_spec = V1PodSpec(containers=[primary_container])
        return Pod(pod_spec=pod_spec, primary_container_name="flytetask")

    @task(
        task_config=get_minimal_pod_task_config(),
        requests=REQUESTS_GPU,
        limits=LIMITS_GPU,
    )
    def pod_task_with_resources(dummy_input: str) -> str:
        return dummy_input

    @dynamic(requests=REQUESTS_GPU, limits=LIMITS_GPU)
    def dynamic_task_with_pod_subtask(dummy_input: str) -> str:
        pod_task_with_resources(dummy_input=dummy_input)
        return dummy_input

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img,
                                 images=[default_img]),
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/User/flyte/workflows",
            distribution_location="s3://my-s3-bucket/fast/123",
        ),
    )

    with context_manager.FlyteContextManager.with_context(
        context_manager.FlyteContextManager.current_context().with_serialization_settings(serialization_settings)
    ) as ctx:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
            )
        ) as ctx:
            input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"dummy_input": "hi"})
            dynamic_job_spec = dynamic_task_with_pod_subtask.dispatch_execute(ctx, input_literal_map)
            # print(dynamic_job_spec)
            assert len(dynamic_job_spec._nodes) == 1
            assert len(dynamic_job_spec.tasks) == 1
            args = " ".join(dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]["args"])
            assert args.startswith(
                "pyflyte-fast-execute --additional-distribution s3://my-s3-bucket/fast/123 "
                "--dest-dir /User/flyte/workflows")
            assert dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]["resources"]["limits"]["cpu"] == "124M"
            assert dynamic_job_spec.tasks[0].k8s_pod.pod_spec["containers"][0]["resources"]["requests"]["gpu"] == "1"

    assert context_manager.FlyteContextManager.size() == 1