예제 #1
0
def test_task_template_security_context(sec_ctx):
    """
    Check that a TaskTemplate exposes the security context it was built with and
    that the context survives a round trip through the Flyte IDL representation.

    A truthy security context whose ``run_as``, ``secrets`` and ``tokens`` are all
    ``None`` is expected to come back from the round trip as ``None``.
    """
    task_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    custom = {"a": 1, "b": {"c": 2, "d": 3}}
    container = task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        parameterizers.LIST_OF_RESOURCES[0],
        {"a": "b"},
        {"d": "e"},
    )
    obj = task.TaskTemplate(
        task_id,
        "python",
        parameterizers.LIST_OF_TASK_METADATA[0],
        parameterizers.LIST_OF_INTERFACES[0],
        custom,
        container=container,
        security_context=sec_ctx,
    )
    assert obj.security_context == sec_ctx

    # An "empty" context (every field unset) is expected to deserialize as None.
    expected = obj.security_context
    if sec_ctx and sec_ctx.run_as is None and sec_ctx.secrets is None and sec_ctx.tokens is None:
        expected = None

    round_tripped = task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl())
    assert round_tripped.security_context == expected
예제 #2
0
def test_task_template(in_tuple):
    """
    Build a TaskTemplate from a (task_metadata, interfaces, resources) tuple and
    verify its accessors as well as the stability of its Flyte IDL round trip.
    """
    task_metadata, interfaces, resources = in_tuple

    task_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    container = task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        resources,
        {"a": "b"},
        {"d": "e"},
    )
    obj = task.TaskTemplate(
        task_id,
        "python",
        task_metadata,
        interfaces,
        {"a": 1, "b": {"c": 2, "d": 3}},
        container=container,
    )

    # Identifier fields
    assert obj.id.resource_type == identifier.ResourceType.TASK
    assert obj.id.project == "project"
    assert obj.id.domain == "domain"
    assert obj.id.name == "name"
    assert obj.id.version == "version"

    # Template fields
    assert obj.type == "python"
    assert obj.metadata == task_metadata
    assert obj.interface == interfaces
    assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}}
    assert obj.container.image == "my_image"
    assert obj.container.resources == resources

    # Serializing, deserializing and serializing again must yield the same text.
    direct_text = text_format.MessageToString(obj.to_flyte_idl())
    round_tripped = task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl())
    assert direct_text == text_format.MessageToString(round_tripped.to_flyte_idl())
예제 #3
0
def get_serializable_task(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
) -> task_models.TaskSpec:
    """
    Build the ``TaskSpec`` model for a task entity.

    :param entity_mapping: Not referenced in this function body; kept as part of the signature.
    :param settings: Serialization settings supplying project, domain and version for the task
        identifier, and deciding whether fast serialization applies.
    :param entity: The task entity whose type, metadata, interface, custom, container, config and
        pod definitions are read to populate the template.
    :return: A ``TaskSpec`` wrapping the freshly built ``TaskTemplate``.
    """
    task_id = _identifier_model.Identifier(
        _identifier_model.ResourceType.TASK,
        settings.project,
        settings.domain,
        entity.name,
        settings.version,
    )

    # For fast registration, we'll need to muck with the command, but only for certain kinds of tasks. Specifically,
    # tasks that rely on user code defined in the container. This should be encapsulated by the auto container
    # parent class.
    # The decision is computed once so that the set/reset calls below are always paired.
    use_fast_command = settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask)
    if use_fast_command:
        entity.set_command_fn(_fast_serialize_command_fn(settings, entity))

    try:
        tt = task_models.TaskTemplate(
            id=task_id,
            type=entity.task_type,
            metadata=entity.metadata.to_taskmetadata_model(),
            interface=entity.interface,
            custom=entity.get_custom(settings),
            container=entity.get_container(settings),
            task_type_version=entity.task_type_version,
            security_context=entity.security_context,
            config=entity.get_config(settings),
            k8s_pod=entity.get_k8s_pod(settings),
        )
    finally:
        # Restore the entity's command function even if building the template raised, so a failed
        # serialization does not leave the entity with the fast-serialize command installed.
        if use_fast_command:
            entity.reset_command_fn()

    return task_models.TaskSpec(template=tt)
예제 #4
0
def test_task_template__k8s_pod_target():
    """
    Verify a TaskTemplate built with a ``k8s_pod`` target and a ``config`` dict
    (and no container argument): field accessors and the Flyte IDL round trip.
    """
    integer_type = types.LiteralType(types.SimpleType.INTEGER)

    task_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    metadata = task.TaskMetadata(
        False,
        task.RuntimeMetadata(1, "v", "f"),
        timedelta(days=1),
        literal_models.RetryStrategy(5),
        False,
        "1.0",
        "deprecated",
        False,
    )
    task_inputs = {"a": interface_models.Variable(integer_type, "description1")}
    task_outputs = {
        "b": interface_models.Variable(integer_type, "description2"),
        "c": interface_models.Variable(integer_type, "description3"),
    }
    typed_interface = interface_models.TypedInterface(task_inputs, task_outputs)
    pod = task.K8sPod(
        metadata=task.K8sObjectMetadata(
            labels={"label": "foo"},
            annotations={"anno": "bar"},
        ),
        pod_spec={"str": "val", "int": 1},
    )

    obj = task.TaskTemplate(
        task_id,
        "python",
        metadata,
        typed_interface,
        {"a": 1, "b": {"c": 2, "d": 3}},
        config={"a": "b"},
        k8s_pod=pod,
    )

    # Identifier fields
    assert obj.id.resource_type == identifier.ResourceType.TASK
    assert obj.id.project == "project"
    assert obj.id.domain == "domain"
    assert obj.id.name == "name"
    assert obj.id.version == "version"

    # Template and pod fields
    assert obj.type == "python"
    assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}}
    expected_pod_metadata = task.K8sObjectMetadata(
        labels={"label": "foo"},
        annotations={"anno": "bar"},
    )
    assert obj.k8s_pod.metadata == expected_pod_metadata
    assert obj.k8s_pod.pod_spec == {"str": "val", "int": 1}

    # Serializing, deserializing and serializing again must yield the same text.
    direct_text = text_format.MessageToString(obj.to_flyte_idl())
    round_tripped = task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl())
    assert direct_text == text_format.MessageToString(round_tripped.to_flyte_idl())

    assert obj.config == {"a": "b"}
예제 #5
0
def test_basic_workflow_promote(mock_task_fetch):
    """
    Check that a workflow promoted from a workflow template matches a user-defined workflow class.

    A one-node workflow class is declared inline, a TaskTemplate for its task is promoted to an
    SdkTask and installed as the return value of ``mock_task_fetch``, and the workflow returned by
    ``SdkWorkflow.promote_from_model`` is compared against the class: interface inputs/outputs,
    node count, and the first input binding of the single node.

    NOTE(review): ``get_workflow_template()`` is defined elsewhere; presumably it returns the
    template corresponding to ``TestPromoteExampleWf`` -- confirm.
    """
    # This section defines a sample workflow from a user
    @_sdk_tasks.inputs(a=_Types.Integer)
    @_sdk_tasks.outputs(b=_Types.Integer, c=_Types.Integer)
    @_sdk_tasks.python_task()
    def demo_task_for_promote(wf_params, a, b, c):
        b.set(a + 1)
        c.set(a + 2)

    @_sdk_workflow.workflow_class()
    class TestPromoteExampleWf(object):
        wf_input = _sdk_workflow.Input(_Types.Integer, required=True)
        my_task_node = demo_task_for_promote(a=wf_input)
        wf_output_b = _sdk_workflow.Output(my_task_node.outputs.b,
                                           sdk_type=_Types.Integer)
        wf_output_c = _sdk_workflow.Output(my_task_node.outputs.c,
                                           sdk_type=_Types.Integer)

    # This section uses the TaskTemplate stored in Admin to promote back to an Sdk Workflow
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    task_interface = _interface.TypedInterface(
        # inputs
        {'a': _interface.Variable(int_type, "description1")},
        # outputs
        {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })
    # Since the promotion of a workflow requires retrieving the task from Admin, we mock the task fetch so that it
    # returns the SdkTask promoted from the template built here.
    task_template = _task_model.TaskTemplate(_identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain",
        "tests.flytekit.unit.common_tests.test_workflow_promote.demo_task_for_promote",
        "version"),
                                             "python_container",
                                             get_sample_task_metadata(),
                                             task_interface,
                                             custom={},
                                             container=get_sample_container())
    sdk_promoted_task = _task.SdkTask.promote_from_model(task_template)
    mock_task_fetch.return_value = sdk_promoted_task
    workflow_template = get_workflow_template()
    promoted_wf = _workflow_common.SdkWorkflow.promote_from_model(
        workflow_template)

    # The promoted workflow must expose the same interface as the hand-written class.
    assert promoted_wf.interface.inputs[
        "wf_input"] == TestPromoteExampleWf.interface.inputs["wf_input"]
    assert promoted_wf.interface.outputs[
        "wf_output_b"] == TestPromoteExampleWf.interface.outputs["wf_output_b"]
    assert promoted_wf.interface.outputs[
        "wf_output_c"] == TestPromoteExampleWf.interface.outputs["wf_output_c"]

    # Both workflows have exactly one node, and its first input binding matches.
    assert len(promoted_wf.nodes) == 1
    assert len(TestPromoteExampleWf.nodes) == 1
    assert promoted_wf.nodes[0].inputs[0] == TestPromoteExampleWf.nodes[
        0].inputs[0]
예제 #6
0
def get_serializable_task(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
    fast: bool,
) -> task_models.TaskSpec:
    """
    Build the ``TaskSpec`` model for a task entity, optionally rewriting the container args for
    fast registration.

    :param entity_mapping: Not referenced in this function body.
    :param settings: Supplies project, domain and version for the identifier and is passed to the
        entity's ``get_custom`` / ``get_container`` / ``get_config`` / ``get_k8s_pod`` calls.
    :param entity: The task entity being serialized.
    :param fast: When truthy and ``entity`` is a ``PythonAutoContainerTask``, the container args
        are prefixed in place with the ``pyflyte-fast-execute`` invocation.
    :return: A ``TaskSpec`` wrapping the built ``TaskTemplate``.
    """
    template = task_models.TaskTemplate(
        id=_identifier_model.Identifier(
            _identifier_model.ResourceType.TASK,
            settings.project,
            settings.domain,
            entity.name,
            settings.version,
        ),
        type=entity.task_type,
        metadata=entity.metadata.to_taskmetadata_model(),
        interface=entity.interface,
        custom=entity.get_custom(settings),
        container=entity.get_container(settings),
        task_type_version=entity.task_type_version,
        security_context=entity.security_context,
        config=entity.get_config(settings),
        k8s_pod=entity.get_k8s_pod(settings),
    )

    # Only tasks that rely on user code defined in the container need their command adjusted for
    # fast registration; everything else is returned untouched. This should be encapsulated by the
    # auto container parent class.
    if not fast or not isinstance(entity, PythonAutoContainerTask):
        return task_models.TaskSpec(template=template)

    fast_execute_prefix = [
        "pyflyte-fast-execute",
        "--additional-distribution",
        "{{ .remote_package_path }}",
        "--dest-dir",
        "{{ .dest_dir }}",
        "--",
    ]
    rewritten_args = fast_execute_prefix + template.container.args[:]

    # Replace the contents of the existing args sequence rather than rebinding it.
    del template.container.args[:]
    template.container.args.extend(rewritten_args)

    return task_models.TaskSpec(template=template)
 def serialize_to_model(
         self, settings: SerializationSettings) -> _task_model.TaskTemplate:
     """
     Build a ``TaskTemplate`` model for this task, store it on ``self._task_template`` and return it.

     :param settings: Supplies project, domain and version for the task identifier and is passed
         to ``get_custom``, ``get_container`` and ``get_config``.
     :return: The newly built ``TaskTemplate``.
     """
     # This doesn't get called from translator unfortunately. Will need to move the translator to use the model
     # objects directly first.
     # Note: This doesn't settle the issue of duplicate registrations. We'll need to figure that out somehow.
     # TODO: After new control plane classes are in, promote the template to a FlyteTask, so that authors of
     #  customized-container tasks have a familiar thing to work with.
     task_id = identifier_models.Identifier(
         identifier_models.ResourceType.TASK,
         settings.project,
         settings.domain,
         self.name,
         settings.version,
     )
     template = _task_model.TaskTemplate(
         task_id,
         self.task_type,
         self.metadata.to_taskmetadata_model(),
         self.interface,
         self.get_custom(settings),
         container=self.get_container(settings),
         config=self.get_config(settings),
     )
     self._task_template = template
     return template
예제 #8
0
def test_task_template_security_context(sec_ctx):
    """
    Check that a TaskTemplate exposes the security context it was built with and
    that its Flyte IDL text form is stable across a serialize/deserialize cycle.
    """
    task_id = identifier.Identifier(
        identifier.ResourceType.TASK,
        "project",
        "domain",
        "name",
        "version",
    )
    container = task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        parameterizers.LIST_OF_RESOURCES[0],
        {"a": "b"},
        {"d": "e"},
    )
    obj = task.TaskTemplate(
        task_id,
        "python",
        parameterizers.LIST_OF_TASK_METADATA[0],
        parameterizers.LIST_OF_INTERFACES[0],
        {"a": 1, "b": {"c": 2, "d": 3}},
        container=container,
        security_context=sec_ctx,
    )
    assert obj.security_context == sec_ctx

    # Serializing, deserializing and serializing again must yield the same text.
    direct_text = text_format.MessageToString(obj.to_flyte_idl())
    round_tripped = task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl())
    assert direct_text == text_format.MessageToString(round_tripped.to_flyte_idl())
예제 #9
0
def test_workflow_closure():
    """
    Build a WorkflowClosure from a single-node workflow template plus its task
    template, and check that it is equal to itself after a Flyte IDL round trip.
    """
    int_literal_type = _types.LiteralType(_types.SimpleType.INTEGER)
    interface_inputs = {"a": _interface.Variable(int_literal_type, "description1")}
    interface_outputs = {
        "b": _interface.Variable(int_literal_type, "description2"),
        "c": _interface.Variable(int_literal_type, "description3"),
    }
    typed_interface = _interface.TypedInterface(interface_inputs, interface_outputs)

    # Node input: a constant integer bound to "a".
    input_binding = _literals.Binding(
        "a",
        _literals.BindingData(
            scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))
        ),
    )
    # Workflow outputs: both bindings use the variable name "b", fed from the
    # node's "b" and "c" outputs respectively.
    output_binding_b = _literals.Binding(
        "b",
        _literals.BindingData(promise=_types.OutputReference("my_node", "b")),
    )
    output_binding_c = _literals.Binding(
        "b",
        _literals.BindingData(promise=_types.OutputReference("my_node", "c")),
    )

    node_metadata = _workflow.NodeMetadata(
        name="node1",
        timeout=timedelta(seconds=10),
        retries=_literals.RetryStrategy(0),
    )

    runtime_metadata = _task.RuntimeMetadata(
        _task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"
    )
    task_metadata = _task.TaskMetadata(
        True,
        runtime_metadata,
        timedelta(days=1),
        _literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!",
    )

    cpu_entry = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_entry], limits=[cpu_entry])

    task_id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    container = _task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        resources,
        {},
        {},
    )
    task_template = _task.TaskTemplate(
        task_id,
        "python",
        task_metadata,
        typed_interface,
        {"a": 1, "b": {"c": 2, "d": 3}},
        container=container,
    )

    node = _workflow.Node(
        id="my_node",
        metadata=node_metadata,
        inputs=[input_binding],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(task_template.id),
    )

    workflow_id = _identifier.Identifier(
        _identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version"
    )
    workflow_template = _workflow.WorkflowTemplate(
        id=workflow_id,
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[output_binding_b, output_binding_c],
    )

    closure = _workflow_closure.WorkflowClosure(
        workflow=workflow_template, tasks=[task_template]
    )
    assert len(closure.tasks) == 1

    restored = _workflow_closure.WorkflowClosure.from_flyte_idl(closure.to_flyte_idl())
    assert closure == restored
예제 #10
0
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import literals as _literals
from flytekit.models import task as _task
from flytekit.models.core import identifier as _identifier
from flytekit.models.core import workflow as _workflow
from tests.flytekit.common import parameterizers

# Task templates used as parametrize inputs for the tests in this module: one TaskTemplate for every
# (task_metadata, interfaces, resources) combination yielded by ``product`` over the parameterizer lists.
# Every template shares the same identifier, type, container image/command/args/env and uses
# ``ArrayJob(2, 2, 2).to_dict()`` as its custom payload.
# NOTE(review): ``product``, ``_array_job`` and ``pytest`` are not among the imports visible above -- confirm
# they are imported elsewhere in the file (``product`` is presumably ``itertools.product``).
LIST_OF_DYNAMIC_TASKS = [
    _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "p", "d", "n",
                               "v"),
        "python",
        task_metadata,
        interfaces,
        _array_job.ArrayJob(2, 2, 2).to_dict(),
        container=_task.Container(
            "my_image",
            ["this", "is", "a", "cmd"],
            ["this", "is", "an", "arg"],
            resources,
            {"a": "b"},
            {"d": "e"},
        ),
    ) for task_metadata, interfaces, resources in product(
        parameterizers.LIST_OF_TASK_METADATA,
        parameterizers.LIST_OF_INTERFACES,
        parameterizers.LIST_OF_RESOURCES,
    )
]


@pytest.mark.parametrize("task", LIST_OF_DYNAMIC_TASKS)
예제 #11
0
    product([True, False], LIST_OF_RUNTIME_METADATA,
            [timedelta(days=i) for i in range(3)], LIST_OF_RETRY_POLICIES,
            LIST_OF_INTERRUPTIBLE, ["1.0"], ["deprecated"])
]

# One TaskTemplate for every (task_metadata, interfaces, resources) combination yielded by ``product`` over
# LIST_OF_TASK_METADATA, LIST_OF_INTERFACES and LIST_OF_RESOURCES. All templates share the same identifier,
# type, container image/command/args/env and a custom dict mixing an int, a list, a string and a nested dict.
LIST_OF_TASK_TEMPLATES = [
    task.TaskTemplate(identifier.Identifier(identifier.ResourceType.TASK,
                                            "project", "domain", "name",
                                            "version"),
                      "python",
                      task_metadata,
                      interfaces, {
                          'a': 1,
                          'b': [1, 2, 3],
                          'c': 'abc',
                          'd': {
                              'x': 1,
                              'y': 2,
                              'z': 3
                          }
                      },
                      container=task.Container("my_image",
                                               ["this", "is", "a", "cmd"],
                                               ["this", "is", "an", "arg"],
                                               resources, {'a': 'b'},
                                               {'d': 'e'}))
    for task_metadata, interfaces, resources in product(
        LIST_OF_TASK_METADATA, LIST_OF_INTERFACES, LIST_OF_RESOURCES)
]

LIST_OF_CONTAINERS = [
    task.Container("my_image", ["this", "is", "a", "cmd"],
예제 #12
0
]

# One TaskTemplate for every (task_metadata, interfaces, resources) combination yielded by ``product`` over
# LIST_OF_TASK_METADATA, LIST_OF_INTERFACES and LIST_OF_RESOURCES. All templates share the same identifier,
# type, container image/command/args/env and a custom dict mixing an int, a list, a string and a nested dict.
LIST_OF_TASK_TEMPLATES = [
    task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project",
                              "domain", "name", "version"),
        "python",
        task_metadata,
        interfaces,
        {
            "a": 1,
            "b": [1, 2, 3],
            "c": "abc",
            "d": {
                "x": 1,
                "y": 2,
                "z": 3
            }
        },
        container=task.Container(
            "my_image",
            ["this", "is", "a", "cmd"],
            ["this", "is", "an", "arg"],
            resources,
            {"a": "b"},
            {"d": "e"},
        ),
    ) for task_metadata, interfaces, resources in product(
        LIST_OF_TASK_METADATA, LIST_OF_INTERFACES, LIST_OF_RESOURCES)
]
예제 #13
0
import pytest
from datetime import timedelta as _timedelta
from google.protobuf import text_format

from flytekit.models import literals as _literals, dynamic_job as _dynamic_job, array_job as _array_job, \
    task as _task
from flytekit.models.core import workflow as _workflow, identifier as _identifier
from tests.flytekit.common import parameterizers

# Task templates used as parametrize inputs for the tests in this module: one TaskTemplate for every
# (task_metadata, interfaces, resources) combination yielded by ``product`` over the parameterizer lists.
# Every template shares the same identifier, type, container image/command/args/env and uses
# ``ArrayJob(2, 2, 2).to_dict()`` as its custom payload.
# NOTE(review): ``product`` is not among the imports visible above -- confirm it is imported elsewhere in the
# file (presumably ``itertools.product``).
LIST_OF_DYNAMIC_TASKS = [
    _task.TaskTemplate(_identifier.Identifier(_identifier.ResourceType.TASK,
                                              "p", "d", "n", "v"),
                       "python",
                       task_metadata,
                       interfaces,
                       _array_job.ArrayJob(2, 2, 2).to_dict(),
                       container=_task.Container("my_image",
                                                 ["this", "is", "a", "cmd"],
                                                 ["this", "is", "an", "arg"],
                                                 resources, {'a': 'b'},
                                                 {'d': 'e'}))
    for task_metadata, interfaces, resources in product(
        parameterizers.LIST_OF_TASK_METADATA,
        parameterizers.LIST_OF_INTERFACES, parameterizers.LIST_OF_RESOURCES)
]


@pytest.mark.parametrize("task", LIST_OF_DYNAMIC_TASKS)
def test_future_task_document(task):
    rs = _literals.RetryStrategy(0)
    nm = _workflow.NodeMetadata('node-name', _timedelta(minutes=10), rs)
예제 #14
0
def get_serializable_task(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
) -> task_models.TaskSpec:
    """
    Build the ``TaskSpec`` model for a task entity, adjusting the container args or the pod spec
    when fast serialization is enabled.

    :param entity_mapping: Not referenced in this function body.
    :param settings: Supplies project, domain and version for the identifier. It is replaced by
        ``settings.with_serialized_context()`` for dynamic ``PythonFunctionTask`` entities before any
        of the entity's ``get_*`` methods are called.
    :param entity: The task entity being serialized.
    :return: A ``TaskSpec`` wrapping the built ``TaskTemplate``.
    """
    task_id = _identifier_model.Identifier(
        _identifier_model.ResourceType.TASK,
        settings.project,
        settings.domain,
        entity.name,
        settings.version,
    )

    if isinstance(
            entity, PythonFunctionTask
    ) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
        # In case of Dynamic tasks, we want to pass the serialization context, so that they can reconstruct the state
        # from the serialization context. This is passed through an environment variable, that is read from
        # during dynamic serialization
        settings = settings.with_serialized_context()

    container = entity.get_container(settings)
    # This pod will be incorrect when doing fast serialize
    pod = entity.get_k8s_pod(settings)

    if settings.should_fast_serialize():
        # This handles container tasks.
        if container and isinstance(entity,
                                    (PythonAutoContainerTask, MapPythonTask)):
            # For fast registration, we'll need to muck with the command, but only for certain kinds of tasks.
            # Specifically, tasks that rely on user code defined in the container. This should be encapsulated by
            # the auto container parent class.
            # Note that this writes the container's private ``_args`` attribute directly.
            container._args = prefix_with_fast_execute(settings,
                                                       container.args)

        # If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect.
        # The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because
        # the pod spec is a K8s library object, and we shouldn't be messing around with it in this file.
        # Because this is an ``elif``, the pod is only rebuilt when the container branch above did not apply.
        elif pod:
            if isinstance(entity, MapPythonTask):
                entity.set_command_prefix(
                    get_command_prefix_for_fast_execute(settings))
                pod = entity.get_k8s_pod(settings)
            else:
                # Temporarily swap in the fast-serialize command function, rebuild the pod with it, then restore.
                entity.set_command_fn(
                    _fast_serialize_command_fn(settings, entity))
                pod = entity.get_k8s_pod(settings)
                entity.reset_command_fn()

    tt = task_models.TaskTemplate(
        id=task_id,
        type=entity.task_type,
        metadata=entity.metadata.to_taskmetadata_model(),
        interface=entity.interface,
        custom=entity.get_custom(settings),
        container=container,
        task_type_version=entity.task_type_version,
        security_context=entity.security_context,
        config=entity.get_config(settings),
        k8s_pod=pod,
        sql=entity.get_sql(settings),
    )
    # NOTE(review): within this function ``set_command_fn`` is only called in the pod branch above, which already
    # resets afterwards; this trailing reset looks redundant -- confirm before removing.
    if settings.should_fast_serialize() and isinstance(
            entity, PythonAutoContainerTask):
        entity.reset_command_fn()
    return task_models.TaskSpec(template=tt)