def test_task_template_security_context(sec_ctx):
    """Round-trip a TaskTemplate carrying ``sec_ctx`` through its Flyte IDL form.

    The template must keep ``sec_ctx`` as given. After ``to_flyte_idl`` /
    ``from_flyte_idl``, a truthy security context whose ``run_as``,
    ``secrets`` and ``tokens`` are all ``None`` is expected to come back as
    ``None``. Any other context must come back unchanged.
    """
    template_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    custom_payload = {"a": 1, "b": {"c": 2, "d": 3}}
    container = task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        parameterizers.LIST_OF_RESOURCES[0],
        {"a": "b"},
        {"d": "e"},
    )
    obj = task.TaskTemplate(
        template_id,
        "python",
        parameterizers.LIST_OF_TASK_METADATA[0],
        parameterizers.LIST_OF_INTERFACES[0],
        custom_payload,
        container=container,
        security_context=sec_ctx,
    )
    assert obj.security_context == sec_ctx

    # An all-empty security context is expected to be dropped by the IDL round trip.
    if sec_ctx and sec_ctx.run_as is None and sec_ctx.secrets is None and sec_ctx.tokens is None:
        expected = None
    else:
        expected = obj.security_context

    restored = task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl())
    assert restored.security_context == expected
def test_task_template(in_tuple):
    """Build a TaskTemplate from a ``(metadata, interface, resources)`` triple.

    Checks that every constructor argument is exposed unchanged. Also checks
    that the template survives a Flyte IDL round trip without textual
    differences.
    """
    task_metadata, interfaces, resources = in_tuple
    container = task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        resources,
        {'a': 'b'},
        {'d': 'e'},
    )
    obj = task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task_metadata,
        interfaces,
        {'a': 1, 'b': {'c': 2, 'd': 3}},
        container=container,
    )

    tid = obj.id
    assert tid.resource_type == identifier.ResourceType.TASK
    for attr_name, expected_value in (
        ("project", "project"),
        ("domain", "domain"),
        ("name", "name"),
        ("version", "version"),
    ):
        assert getattr(tid, attr_name) == expected_value

    assert obj.type == "python"
    assert obj.metadata == task_metadata
    assert obj.interface == interfaces
    assert obj.custom == {'a': 1, 'b': {'c': 2, 'd': 3}}
    assert obj.container.image == "my_image"
    assert obj.container.resources == resources

    serialized = obj.to_flyte_idl()
    reparsed = task.TaskTemplate.from_flyte_idl(serialized)
    assert text_format.MessageToString(serialized) == text_format.MessageToString(reparsed.to_flyte_idl())
def get_serializable_task(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
) -> task_models.TaskSpec:
    """Translate a local task entity into a ``TaskSpec`` model.

    :param entity_mapping: not used by this function; kept so the signature matches the other translators.
    :param settings: supplies project/domain/version for the task identifier and controls fast serialization.
    :param entity: the task being serialized.
    :returns: a ``TaskSpec`` wrapping the generated ``TaskTemplate``.

    When fast serialization is enabled and ``entity`` is a ``PythonAutoContainerTask``, the entity's command
    function is temporarily swapped for the fast-execute one while the template is built. It is always
    restored afterwards, even if building the template raises.
    """
    task_id = _identifier_model.Identifier(
        _identifier_model.ResourceType.TASK,
        settings.project,
        settings.domain,
        entity.name,
        settings.version,
    )

    # For fast registration, we'll need to muck with the command, but only for certain kinds of tasks.
    # Specifically, tasks that rely on user code defined in the container. This should be encapsulated by
    # the auto container parent class.
    patch_command = settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask)
    if patch_command:
        entity.set_command_fn(_fast_serialize_command_fn(settings, entity))

    try:
        tt = task_models.TaskTemplate(
            id=task_id,
            type=entity.task_type,
            metadata=entity.metadata.to_taskmetadata_model(),
            interface=entity.interface,
            custom=entity.get_custom(settings),
            container=entity.get_container(settings),
            task_type_version=entity.task_type_version,
            security_context=entity.security_context,
            config=entity.get_config(settings),
            k8s_pod=entity.get_k8s_pod(settings),
        )
    finally:
        # Undo the command override even on failure, so the entity is not left in fast-serialize mode
        # for whoever uses it next.
        if patch_command:
            entity.reset_command_fn()

    return task_models.TaskSpec(template=tt)
def test_task_template__k8s_pod_target():
    """Check that a TaskTemplate targeting a K8s pod (no container) keeps its fields.

    Also checks that it round-trips through Flyte IDL unchanged.
    """
    int_type = types.LiteralType(types.SimpleType.INTEGER)
    task_metadata = task.TaskMetadata(
        False,
        task.RuntimeMetadata(1, "v", "f"),
        timedelta(days=1),
        literal_models.RetryStrategy(5),
        False,
        "1.0",
        "deprecated",
        False,
    )
    input_vars = {"a": interface_models.Variable(int_type, "description1")}
    output_vars = {
        "b": interface_models.Variable(int_type, "description2"),
        "c": interface_models.Variable(int_type, "description3"),
    }
    pod = task.K8sPod(
        metadata=task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}),
        pod_spec={"str": "val", "int": 1},
    )
    obj = task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task_metadata,
        interface_models.TypedInterface(input_vars, output_vars),
        {"a": 1, "b": {"c": 2, "d": 3}},
        config={"a": "b"},
        k8s_pod=pod,
    )

    tid = obj.id
    assert (tid.resource_type, tid.project, tid.domain, tid.name, tid.version) == (
        identifier.ResourceType.TASK,
        "project",
        "domain",
        "name",
        "version",
    )
    assert obj.type == "python"
    assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}}
    assert obj.k8s_pod.metadata == task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"})
    assert obj.k8s_pod.pod_spec == {"str": "val", "int": 1}

    serialized = obj.to_flyte_idl()
    restored = task.TaskTemplate.from_flyte_idl(serialized)
    assert text_format.MessageToString(serialized) == text_format.MessageToString(restored.to_flyte_idl())
    assert obj.config == {"a": "b"}
def test_basic_workflow_promote(mock_task_fetch):
    """Promote a stored workflow template back into an SdkWorkflow and compare it with a user-defined one.

    ``mock_task_fetch`` is a mock whose ``return_value`` is set to an SdkTask promoted from a hand-built
    TaskTemplate. Presumably it replaces the Admin task lookup during workflow promotion; confirm against
    the patch/fixture that supplies it.
    """
    # This section defines a sample workflow from a user
    @_sdk_tasks.inputs(a=_Types.Integer)
    @_sdk_tasks.outputs(b=_Types.Integer, c=_Types.Integer)
    @_sdk_tasks.python_task()
    def demo_task_for_promote(wf_params, a, b, c):
        b.set(a + 1)
        c.set(a + 2)

    # Reference workflow: one input, one task node, two outputs. Its interface and node are compared against
    # the promoted workflow below.
    @_sdk_workflow.workflow_class()
    class TestPromoteExampleWf(object):
        wf_input = _sdk_workflow.Input(_Types.Integer, required=True)
        my_task_node = demo_task_for_promote(a=wf_input)
        wf_output_b = _sdk_workflow.Output(my_task_node.outputs.b, sdk_type=_Types.Integer)
        wf_output_c = _sdk_workflow.Output(my_task_node.outputs.c, sdk_type=_Types.Integer)

    # This section uses the TaskTemplate stored in Admin to promote back to an Sdk Workflow
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    task_interface = _interface.TypedInterface(
        # inputs
        {'a': _interface.Variable(int_type, "description1")},
        # outputs
        {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    # Promoting a workflow requires retrieving its tasks from Admin. Instead, we build the TaskTemplate
    # locally and make the mocked fetch return the SdkTask promoted from it.
    # NOTE(review): the task name below presumably has to match the name derived for demo_task_for_promote
    # (module path + function name); confirm if this test is moved or renamed.
    task_template = _task_model.TaskTemplate(_identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain",
        "tests.flytekit.unit.common_tests.test_workflow_promote.demo_task_for_promote",
        "version"),
        "python_container",
        get_sample_task_metadata(),
        task_interface,
        custom={},
        container=get_sample_container())
    sdk_promoted_task = _task.SdkTask.promote_from_model(task_template)
    mock_task_fetch.return_value = sdk_promoted_task
    workflow_template = get_workflow_template()
    promoted_wf = _workflow_common.SdkWorkflow.promote_from_model(
        workflow_template)

    # The promoted workflow's interface must match the user-defined workflow's, variable by variable.
    assert promoted_wf.interface.inputs[
        "wf_input"] == TestPromoteExampleWf.interface.inputs["wf_input"]
    assert promoted_wf.interface.outputs[
        "wf_output_b"] == TestPromoteExampleWf.interface.outputs["wf_output_b"]
    assert promoted_wf.interface.outputs[
        "wf_output_c"] == TestPromoteExampleWf.interface.outputs["wf_output_c"]

    # Both workflows have exactly one node, and that node's first input binding must match.
    assert len(promoted_wf.nodes) == 1
    assert len(TestPromoteExampleWf.nodes) == 1
    assert promoted_wf.nodes[0].inputs[0] == TestPromoteExampleWf.nodes[
        0].inputs[0]
def get_serializable_task(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
    fast: bool,
) -> task_models.TaskSpec:
    """Translate a local task entity into a ``TaskSpec`` model.

    :param entity_mapping: not used by this function; kept for signature compatibility.
    :param settings: supplies project/domain/version for the task identifier and is passed to the entity's
        ``get_*`` hooks.
    :param entity: the task being serialized.
    :param fast: when true and ``entity`` is a ``PythonAutoContainerTask``, the container args are prefixed
        in place with a ``pyflyte-fast-execute`` invocation. Its ``{{ .remote_package_path }}`` and
        ``{{ .dest_dir }}`` placeholders are presumably filled in later at registration time.
    :returns: a ``TaskSpec`` wrapping the generated ``TaskTemplate``.
    """
    template = task_models.TaskTemplate(
        id=_identifier_model.Identifier(
            _identifier_model.ResourceType.TASK,
            settings.project,
            settings.domain,
            entity.name,
            settings.version,
        ),
        type=entity.task_type,
        metadata=entity.metadata.to_taskmetadata_model(),
        interface=entity.interface,
        custom=entity.get_custom(settings),
        container=entity.get_container(settings),
        task_type_version=entity.task_type_version,
        security_context=entity.security_context,
        config=entity.get_config(settings),
        k8s_pod=entity.get_k8s_pod(settings),
    )

    # Only tasks whose user code lives in the container (auto-container tasks) get the fast-execute wrapper.
    needs_fast_prefix = fast and isinstance(entity, PythonAutoContainerTask)
    if needs_fast_prefix:
        container_args = template.container.args
        # The right-hand side is built from the current args before the in-place slice assignment replaces them.
        container_args[:] = [
            "pyflyte-fast-execute",
            "--additional-distribution",
            "{{ .remote_package_path }}",
            "--dest-dir",
            "{{ .dest_dir }}",
            "--",
            *container_args,
        ]

    return task_models.TaskSpec(template=template)
def serialize_to_model(self, settings: SerializationSettings) -> _task_model.TaskTemplate:
    """Build this task's ``TaskTemplate`` model, cache it on ``self._task_template`` and return it.

    This doesn't get called from the translator. The translator would first have to be moved to use the
    model objects directly. It also doesn't settle the issue of duplicate registrations.

    NOTE(review): unlike the translator's ``get_serializable_task``, this passes no ``task_type_version``,
    ``security_context`` or ``k8s_pod``. Confirm whether the cached template should carry them.
    TODO: After new control plane classes are in, promote the template to a FlyteTask, so that authors of
    customized-container tasks have a familiar thing to work with.
    """
    task_identifier = identifier_models.Identifier(
        identifier_models.ResourceType.TASK,
        settings.project,
        settings.domain,
        self.name,
        settings.version,
    )
    template = _task_model.TaskTemplate(
        task_identifier,
        self.task_type,
        self.metadata.to_taskmetadata_model(),
        self.interface,
        self.get_custom(settings),
        container=self.get_container(settings),
        config=self.get_config(settings),
    )
    self._task_template = template
    return template
def test_task_template_security_context(sec_ctx):
    """Check that a TaskTemplate built with ``sec_ctx`` exposes it unchanged.

    Also checks that the template round-trips losslessly through Flyte IDL.
    """
    container = task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        parameterizers.LIST_OF_RESOURCES[0],
        {"a": "b"},
        {"d": "e"},
    )
    obj = task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        parameterizers.LIST_OF_TASK_METADATA[0],
        parameterizers.LIST_OF_INTERFACES[0],
        {"a": 1, "b": {"c": 2, "d": 3}},
        container=container,
        security_context=sec_ctx,
    )
    assert obj.security_context == sec_ctx

    idl_message = obj.to_flyte_idl()
    rebuilt = task.TaskTemplate.from_flyte_idl(idl_message)
    assert text_format.MessageToString(idl_message) == text_format.MessageToString(rebuilt.to_flyte_idl())
def test_workflow_closure():
    """Round-trip a one-task, one-node WorkflowClosure through its Flyte IDL form.

    The workflow declares input ``a`` and outputs ``b`` and ``c``. Each output is bound to the output of the
    same name on node ``my_node``. The closure rebuilt from IDL must compare equal to the original.
    """
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    typed_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")},
        {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    # Node input 'a' is bound to the literal 5.
    b0 = _literals.Binding(
        'a',
        _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))))
    # Workflow outputs: each declared output gets its own binding, named after the output it fills.
    b1 = _literals.Binding('b', _literals.BindingData(promise=_types.OutputReference('my_node', 'b')))
    b2 = _literals.Binding('c', _literals.BindingData(promise=_types.OutputReference('my_node', 'c')))

    node_metadata = _workflow.NodeMetadata(name='node1',
                                           timeout=timedelta(seconds=10),
                                           retries=_literals.RetryStrategy(0))

    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
        timedelta(days=1),
        _literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!")

    cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface,
        {'a': 1, 'b': {'c': 2, 'd': 3}},
        container=_task.Container("my_image", ["this", "is", "a", "cmd"],
                                  ["this", "is", "an", "arg"], resources, {}, {}))

    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(id='my_node',
                          metadata=node_metadata,
                          inputs=[b0],
                          upstream_node_ids=[],
                          output_aliases=[],
                          task_node=task_node)

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import literals as _literals
from flytekit.models import task as _task
from flytekit.models.core import identifier as _identifier
from flytekit.models.core import workflow as _workflow
from tests.flytekit.common import parameterizers

# One TaskTemplate per combination of the shared task-metadata, interface and resource parameterizers.
# Each template uses ArrayJob(2, 2, 2).to_dict() as its custom payload.
# NOTE(review): `_array_job`, `product` and `pytest` are used below but are not imported in the lines
# visible here; confirm they are imported elsewhere in this file.
LIST_OF_DYNAMIC_TASKS = [
    _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "p", "d", "n", "v"),
        "python",
        task_metadata,
        interfaces,
        _array_job.ArrayJob(2, 2, 2).to_dict(),
        container=_task.Container(
            "my_image",
            ["this", "is", "a", "cmd"],
            ["this", "is", "an", "arg"],
            resources,
            {"a": "b"},
            {"d": "e"},
        ),
    )
    for task_metadata, interfaces, resources in product(
        parameterizers.LIST_OF_TASK_METADATA,
        parameterizers.LIST_OF_INTERFACES,
        parameterizers.LIST_OF_RESOURCES,
    )
]


# NOTE(review): the test function this decorator parametrizes is not visible in this chunk.
@pytest.mark.parametrize("task", LIST_OF_DYNAMIC_TASKS)
product([True, False], LIST_OF_RUNTIME_METADATA, [timedelta(days=i) for i in range(3)], LIST_OF_RETRY_POLICIES, LIST_OF_INTERRUPTIBLE, ["1.0"], ["deprecated"]) ] LIST_OF_TASK_TEMPLATES = [ task.TaskTemplate(identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), "python", task_metadata, interfaces, { 'a': 1, 'b': [1, 2, 3], 'c': 'abc', 'd': { 'x': 1, 'y': 2, 'z': 3 } }, container=task.Container("my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {'a': 'b'}, {'d': 'e'})) for task_metadata, interfaces, resources in product( LIST_OF_TASK_METADATA, LIST_OF_INTERFACES, LIST_OF_RESOURCES) ] LIST_OF_CONTAINERS = [ task.Container("my_image", ["this", "is", "a", "cmd"],
] LIST_OF_TASK_TEMPLATES = [ task.TaskTemplate( identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), "python", task_metadata, interfaces, { "a": 1, "b": [1, 2, 3], "c": "abc", "d": { "x": 1, "y": 2, "z": 3 } }, container=task.Container( "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, ), ) for task_metadata, interfaces, resources in product( LIST_OF_TASK_METADATA, LIST_OF_INTERFACES, LIST_OF_RESOURCES) ]
import pytest
from datetime import timedelta as _timedelta
from google.protobuf import text_format
from flytekit.models import literals as _literals, dynamic_job as _dynamic_job, array_job as _array_job, \
    task as _task
from flytekit.models.core import workflow as _workflow, identifier as _identifier
from tests.flytekit.common import parameterizers

# One TaskTemplate per combination of the shared task-metadata, interface and resource parameterizers.
# Each template uses ArrayJob(2, 2, 2).to_dict() as its custom payload.
# NOTE(review): `product` is used below but is not imported in the lines visible here; confirm it is
# imported (presumably from itertools) elsewhere in this file.
LIST_OF_DYNAMIC_TASKS = [
    _task.TaskTemplate(_identifier.Identifier(_identifier.ResourceType.TASK, "p", "d", "n", "v"),
                       "python",
                       task_metadata,
                       interfaces,
                       _array_job.ArrayJob(2, 2, 2).to_dict(),
                       container=_task.Container("my_image",
                                                 ["this", "is", "a", "cmd"],
                                                 ["this", "is", "an", "arg"],
                                                 resources,
                                                 {'a': 'b'},
                                                 {'d': 'e'}))
    for task_metadata, interfaces, resources in product(
        parameterizers.LIST_OF_TASK_METADATA,
        parameterizers.LIST_OF_INTERFACES,
        parameterizers.LIST_OF_RESOURCES)
]


@pytest.mark.parametrize("task", LIST_OF_DYNAMIC_TASKS)
def test_future_task_document(task):
    """Test fixture setup for a dynamic-job document built from each template in ``LIST_OF_DYNAMIC_TASKS``.

    NOTE(review): in the lines visible here, ``task`` and ``nm`` are never used and nothing is asserted;
    the rest of this test appears to be missing from this chunk.
    """
    # No retries, 10-minute timeout for the node that would run the task.
    rs = _literals.RetryStrategy(0)
    nm = _workflow.NodeMetadata('node-name', _timedelta(minutes=10), rs)
def get_serializable_task(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
) -> task_models.TaskSpec:
    """Translate a local task entity into a ``TaskSpec`` model.

    :param entity_mapping: not used by this function in the code shown.
    :param settings: serialization settings. Dynamic ``PythonFunctionTask``s get a copy carrying the
        serialized context.
    :param entity: the task being serialized.
    :returns: a ``TaskSpec`` wrapping the generated ``TaskTemplate``.

    With fast serialization enabled, one of two things happens:

    * Container-based auto-container and map tasks have their container args prefixed with the
      fast-execute command.
    * Pod-based tasks have their pod spec regenerated after temporarily changing the entity's command
      prefix or command function.
    """
    task_id = _identifier_model.Identifier(
        _identifier_model.ResourceType.TASK,
        settings.project,
        settings.domain,
        entity.name,
        settings.version,
    )
    if isinstance(
            entity, PythonFunctionTask
    ) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC:
        # In case of Dynamic tasks, we want to pass the serialization context, so that they can reconstruct
        # the state from it. It is passed through an environment variable that is read during dynamic
        # serialization.
        settings = settings.with_serialized_context()

    container = entity.get_container(settings)
    # This pod will be incorrect when doing fast serialize
    pod = entity.get_k8s_pod(settings)

    if settings.should_fast_serialize():
        # This handles container tasks.
        if container and isinstance(entity, (PythonAutoContainerTask, MapPythonTask)):
            # For fast registration, we'll need to muck with the command, but only for certain kinds of tasks.
            # Specifically, tasks that rely on user code defined in the container. This should be encapsulated
            # by the auto container parent class.
            # NOTE(review): this writes the private `_args` attribute of the container model directly.
            container._args = prefix_with_fast_execute(settings, container.args)

        # If the pod spec is not None, we have to get it again, because the one we retrieved above will be
        # incorrect. We call get_k8s_pod again, instead of just modifying the command in this file, because
        # the pod spec is a K8s library object, and we shouldn't be messing around with it in this file.
        elif pod:
            if isinstance(entity, MapPythonTask):
                # NOTE(review): nothing visible here clears this command prefix after the pod is rebuilt;
                # confirm whether it should be reset like the command fn below.
                entity.set_command_prefix(
                    get_command_prefix_for_fast_execute(settings))
                pod = entity.get_k8s_pod(settings)
            else:
                # Temporarily swap in the fast-execute command fn, rebuild the pod with it, then restore it.
                entity.set_command_fn(
                    _fast_serialize_command_fn(settings, entity))
                pod = entity.get_k8s_pod(settings)
                entity.reset_command_fn()

    tt = task_models.TaskTemplate(
        id=task_id,
        type=entity.task_type,
        metadata=entity.metadata.to_taskmetadata_model(),
        interface=entity.interface,
        custom=entity.get_custom(settings),
        container=container,
        task_type_version=entity.task_type_version,
        security_context=entity.security_context,
        config=entity.get_config(settings),
        k8s_pod=pod,
        sql=entity.get_sql(settings),
    )
    # NOTE(review): in the pod branch above the command fn has already been reset. In the container branch
    # set_command_fn was never called, so this reset runs on an entity whose command fn was not changed here.
    if settings.should_fast_serialize() and isinstance(
            entity, PythonAutoContainerTask):
        entity.reset_command_fn()
    return task_models.TaskSpec(template=tt)