def test_execution_annotation_overrides(mock_client_factory):
    """Launching with annotation_overrides must forward the annotations into the ExecutionSpec."""
    # Wire the client factory to hand back a client whose create_execution is observable.
    client_mock = MagicMock()
    client_mock.create_execution = MagicMock(
        return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn")
    )
    mock_client_factory.return_value = client_mock

    # Stand-in launch plan whose ``id`` property yields a fixed identifier.
    lp_mock = MagicMock()
    type(lp_mock).id = PropertyMock(
        return_value=identifier.Identifier(
            identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"
        )
    )

    override_annotations = _common_models.Annotations({"my": "annotation"})
    engine.FlyteLaunchPlan(lp_mock).launch(
        "xp",
        "xd",
        "xn",
        literals.LiteralMap({}),
        notification_overrides=[],
        annotation_overrides=override_annotations,
    )

    # The spec handed to Admin must disable notifications and carry the annotations through.
    client_mock.create_execution.assert_called_once_with(
        "xp",
        "xd",
        "xn",
        _execution_models.ExecutionSpec(
            identifier.Identifier(
                identifier.ResourceType.LAUNCH_PLAN,
                "project",
                "domain",
                "name",
                "version",
            ),
            _execution_models.ExecutionMetadata(
                _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0
            ),
            disable_all=True,
            annotations=override_annotations,
        ),
        literals.LiteralMap({}),
    )
def test_basic_workflow_promote(mock_task_fetch):
    """Promote an admin WorkflowTemplate back into an SdkWorkflow and compare it
    against the equivalent user-authored workflow class."""
    # This section defines a sample workflow from a user
    @_sdk_tasks.inputs(a=_Types.Integer)
    @_sdk_tasks.outputs(b=_Types.Integer, c=_Types.Integer)
    @_sdk_tasks.python_task()
    def demo_task_for_promote(wf_params, a, b, c):
        b.set(a + 1)
        c.set(a + 2)

    @_sdk_workflow.workflow_class()
    class TestPromoteExampleWf(object):
        wf_input = _sdk_workflow.Input(_Types.Integer, required=True)
        my_task_node = demo_task_for_promote(a=wf_input)
        wf_output_b = _sdk_workflow.Output(my_task_node.outputs.b, sdk_type=_Types.Integer)
        wf_output_c = _sdk_workflow.Output(my_task_node.outputs.c, sdk_type=_Types.Integer)

    # This section uses the TaskTemplate stored in Admin to promote back to an Sdk Workflow
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    task_interface = _interface.TypedInterface(
        # inputs
        {'a': _interface.Variable(int_type, "description1")},
        # outputs
        {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        }
    )

    # Since the promotion of a workflow requires retrieving the task from Admin, we mock the SdkTask to return
    task_template = _task_model.TaskTemplate(
        _identifier.Identifier(
            _identifier.ResourceType.TASK,
            "project",
            "domain",
            "tests.flytekit.unit.common_tests.test_workflow_promote.demo_task_for_promote",
            "version",
        ),
        "python_container",
        get_sample_task_metadata(),
        task_interface,
        custom={},
        container=get_sample_container()
    )
    sdk_promoted_task = _task.SdkTask.promote_from_model(task_template)
    mock_task_fetch.return_value = sdk_promoted_task

    workflow_template = get_workflow_template()
    promoted_wf = _workflow_common.SdkWorkflow.promote_from_model(workflow_template)

    # The promoted workflow must agree with the locally-defined one on interface and nodes.
    assert promoted_wf.interface.inputs["wf_input"] == TestPromoteExampleWf.interface.inputs["wf_input"]
    assert promoted_wf.interface.outputs["wf_output_b"] == TestPromoteExampleWf.interface.outputs["wf_output_b"]
    assert promoted_wf.interface.outputs["wf_output_c"] == TestPromoteExampleWf.interface.outputs["wf_output_c"]
    assert len(promoted_wf.nodes) == 1
    assert len(TestPromoteExampleWf.nodes) == 1
    assert promoted_wf.nodes[0].inputs[0] == TestPromoteExampleWf.nodes[0].inputs[0]
def activate_all_impl(project, domain, version, pkgs, ignore_schedules=False):
    """Activate every launch plan discovered in ``pkgs`` under project/domain/version.

    Scheduled launch plans are left untouched when ``ignore_schedules`` is set.
    """
    # TODO: This should be a transaction to ensure all or none are updated
    # TODO: We should optionally allow deactivation of missing launch plans
    # Discover all launch plans by loading the modules
    discovered = iterate_registerable_entities_in_order(
        pkgs, include_entities={_SdkLaunchPlan}, detect_unreferenced_entities=False
    )
    for module, key, launch_plan in discovered:
        launch_plan._id = _identifier.Identifier(
            _identifier.ResourceType.LAUNCH_PLAN,
            project,
            domain,
            _utils.fqdn(module.__name__, key, entity_type=launch_plan.resource_type),
            version,
        )
        # Skip scheduled launch plans when the caller asked to ignore schedules.
        if launch_plan.is_scheduled and ignore_schedules:
            continue
        launch_plan.update(_launch_plan_model.LaunchPlanState.ACTIVE)
def test_task_execution_identifier():
    """A TaskExecutionIdentifier should expose its parts and survive an IDL round trip."""
    task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    node_exec_id = identifier.NodeExecutionIdentifier("node_id", wf_exec_id)

    original = identifier.TaskExecutionIdentifier(task_id, node_exec_id, 3)
    restored = identifier.TaskExecutionIdentifier.from_flyte_idl(original.to_flyte_idl())
    assert restored == original

    # Original and round-tripped copy must expose identical fields.
    for obj in (original, restored):
        assert obj.retry_attempt == 3
        assert obj.task_id == task_id
        assert obj.node_execution_id == node_exec_id
def test_identifier():
    """An Identifier should expose its fields and survive an IDL round trip."""
    original = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    restored = identifier.Identifier.from_flyte_idl(original.to_flyte_idl())
    assert restored == original

    # Both copies must agree field-by-field with the constructor arguments.
    for obj in (original, restored):
        assert obj.project == "project"
        assert obj.domain == "domain"
        assert obj.name == "name"
        assert obj.version == "version"
        assert obj.resource_type == identifier.ResourceType.TASK
def test_flyte_workflow_integration(mock_url, mock_client_manager):
    """Fetching a workflow from Admin should yield a FlyteWorkflow wrapping the admin model.

    Fix: the paginated-list mock previously passed ``returnValue`` — which MagicMock
    silently treats as an arbitrary configuration attribute — instead of
    ``return_value``, so the stubbed response was never returned by the mocked call.
    """
    mock_url.get.return_value = "localhost"
    admin_workflow = _workflow_models.Workflow(
        _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p1", "d1", "n1", "v1"),
        _MagicMock(),
    )
    mock_client = _MagicMock()
    # ``return_value`` (not ``returnValue``) is the MagicMock API for stubbing a result.
    mock_client.list_workflows_paginated = _MagicMock(return_value=([admin_workflow], ""))
    mock_client_manager.return_value.client = mock_client

    workflow = _workflow.FlyteWorkflow.fetch("p1", "d1", "n1", "v1")
    assert workflow.entity_type_text == "Workflow"
    assert workflow.id == admin_workflow.id
def test_flyte_task_fetch(mock_url, mock_client_manager):
    """fetch/fetch_latest should resolve tasks against the mocked Admin listing."""
    mock_url.get.return_value = "localhost"

    def make_admin_task(version):
        # Helper: admin Task model with a fixed project/domain/name and the given version.
        return _task_models.Task(
            _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", version),
            _MagicMock(),
        )

    admin_task_v1 = make_admin_task("v1")
    admin_task_v2 = make_admin_task("v2")

    mock_client = _MagicMock()
    # Admin lists newest-first; fetch_latest should pick the head of the list.
    mock_client.list_tasks_paginated = _MagicMock(return_value=([admin_task_v2, admin_task_v1], ""))
    mock_client_manager.return_value.client = mock_client

    latest_task = _task.FlyteTask.fetch_latest("p1", "d1", "n1")
    task_v1 = _task.FlyteTask.fetch("p1", "d1", "n1", "v1")
    task_v2 = _task.FlyteTask.fetch("p1", "d1", "n1", "v2")

    assert task_v1.id == admin_task_v1.id
    assert task_v1.id != latest_task.id
    assert task_v2.id == latest_task.id == admin_task_v2.id
    for fetched in (task_v1, task_v2):
        assert fetched.entity_type_text == "Task"
        assert fetched.resource_type == _identifier.ResourceType.TASK
def test__extract_files(load_mock):
    """_extract_pair should return the identifier IDL and serialized task spec loaded from disk."""
    # NOTE(review): a second function with this exact name is defined later in this
    # file; under Python semantics the later definition shadows this one, so this
    # test never actually runs. Consider renaming one of the two.
    id = _core_identifier.Identifier(_core_identifier.ResourceType.TASK, 'myproject', 'development', 'name', 'v')
    t = get_sample_task()
    # Serialization reads image/project/domain from configuration, so override it locally.
    with TemporaryConfiguration("", internal_overrides={
        'image': 'myflyteimage:v123',
        'project': 'myflyteproject',
        'domain': 'development'
    }):
        task_spec = t.serialize()
        load_mock.side_effect = [id.to_flyte_idl(), task_spec]
        new_id, entity = _main._extract_pair('a', 'b')
        assert new_id == id.to_flyte_idl()
        assert task_spec == entity
def test__extract_files(load_mock):
    """_extract_pair should return the identifier IDL and serialized task spec loaded from disk."""
    # NOTE(review): a function with this same name is defined earlier in this file;
    # this later definition shadows it, so only this copy is collected and run.
    id = _core_identifier.Identifier(_core_identifier.ResourceType.TASK, "myproject", "development", "name", "v")
    t = get_sample_task()
    # Serialization reads image/project/domain from configuration, so override it locally.
    with TemporaryConfiguration(
        "",
        internal_overrides={
            "image": "myflyteimage:v123",
            "project": "myflyteproject",
            "domain": "development"
        },
    ):
        task_spec = t.serialize()
        load_mock.side_effect = [id.to_flyte_idl(), task_spec]
        new_id, entity = _main._extract_pair("a", "b")
        assert new_id == id.to_flyte_idl()
        assert task_spec == entity
def get_serializable_task(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
    fast: bool,
) -> task_models.TaskSpec:
    """Convert a local task entity into a serializable TaskSpec model.

    When ``fast`` registration is requested for container-based Python tasks, the
    container command is rewritten so user code is downloaded at execution time.
    """
    template = task_models.TaskTemplate(
        id=_identifier_model.Identifier(
            _identifier_model.ResourceType.TASK,
            settings.project,
            settings.domain,
            entity.name,
            settings.version,
        ),
        type=entity.task_type,
        metadata=entity.metadata.to_taskmetadata_model(),
        interface=entity.interface,
        custom=entity.get_custom(settings),
        container=entity.get_container(settings),
        task_type_version=entity.task_type_version,
        security_context=entity.security_context,
        config=entity.get_config(settings),
        k8s_pod=entity.get_k8s_pod(settings),
    )
    # For fast registration, we'll need to muck with the command, but only for certain
    # kinds of tasks — those that rely on user code defined in the container. This
    # should be encapsulated by the auto container parent class.
    if fast and isinstance(entity, PythonAutoContainerTask):
        fast_prefix = [
            "pyflyte-fast-execute",
            "--additional-distribution",
            "{{ .remote_package_path }}",
            "--dest-dir",
            "{{ .dest_dir }}",
            "--",
        ]
        # In-place slice assignment keeps the same args list object, merely re-filled.
        template.container.args[:] = fast_prefix + template.container.args
    return task_models.TaskSpec(template=template)
def test_task_template__k8s_pod_target():
    """A TaskTemplate targeting a K8sPod should retain pod metadata/spec and round-trip via IDL."""
    int_type = types.LiteralType(types.SimpleType.INTEGER)
    obj = task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task.TaskMetadata(
            False,
            task.RuntimeMetadata(1, "v", "f"),
            timedelta(days=1),
            literal_models.RetryStrategy(5),
            False,
            "1.0",
            "deprecated",
        ),
        interface_models.TypedInterface(
            # inputs
            {"a": interface_models.Variable(int_type, "description1")},
            # outputs
            {
                "b": interface_models.Variable(int_type, "description2"),
                "c": interface_models.Variable(int_type, "description3"),
            },
        ),
        {"a": 1, "b": {"c": 2, "d": 3}},
        config={"a": "b"},
        k8s_pod=task.K8sPod(
            metadata=task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"}),
            pod_spec={"str": "val", "int": 1},
        ),
    )
    assert obj.id.resource_type == identifier.ResourceType.TASK
    assert obj.id.project == "project"
    assert obj.id.domain == "domain"
    assert obj.id.name == "name"
    assert obj.id.version == "version"
    assert obj.type == "python"
    assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}}
    assert obj.k8s_pod.metadata == task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"})
    assert obj.k8s_pod.pod_spec == {"str": "val", "int": 1}
    # Round-tripping through protobuf must be lossless (compare the text forms).
    assert text_format.MessageToString(obj.to_flyte_idl()) == text_format.MessageToString(
        task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl()
    )
    assert obj.config == {"a": "b"}
def test_old_style_role():
    """A legacy LaunchPlanSpec IDL carrying the old ``auth`` field must still parse.

    NOTE(review): the assertion checks that the kubernetes service account string
    surfaces as ``assumable_iam_role`` after conversion — presumably the intended
    legacy mapping; confirm against ``LaunchPlanSpec.from_flyte_idl``.
    """
    identifier_model = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")

    s = schedule.Schedule("asdf", "1 3 4 5 6 7")
    launch_plan_metadata_model = launch_plan.LaunchPlanMetadata(schedule=s, notifications=[])

    v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    p = interface.Parameter(var=v)
    parameter_map = interface.ParameterMap({"ppp": p})

    fixed_inputs = literals.LiteralMap({
        "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))
    })

    labels_model = common.Labels({})
    annotations_model = common.Annotations({"my": "annotation"})
    raw_data_output_config = common.RawOutputDataConfig("s3://bucket")

    # Old-style IDL put auth directly on the spec as an Auth message.
    old_role = _launch_plan_idl.Auth(kubernetes_service_account="my:service:account")

    old_style_spec = _launch_plan_idl.LaunchPlanSpec(
        workflow_id=identifier_model.to_flyte_idl(),
        entity_metadata=launch_plan_metadata_model.to_flyte_idl(),
        default_inputs=parameter_map.to_flyte_idl(),
        fixed_inputs=fixed_inputs.to_flyte_idl(),
        labels=labels_model.to_flyte_idl(),
        annotations=annotations_model.to_flyte_idl(),
        raw_output_data_config=raw_data_output_config.to_flyte_idl(),
        auth=old_role,
    )

    lp_spec = launch_plan.LaunchPlanSpec.from_flyte_idl(old_style_spec)
    assert lp_spec.auth_role.assumable_iam_role == "my:service:account"
def serialize_to_model(self, settings: SerializationSettings) -> _task_model.TaskTemplate:
    """Build, cache, and return the TaskTemplate model for this task under ``settings``."""
    # This doesn't get called from translator unfortunately. Will need to move the
    # translator to use the model objects directly first.
    # Note: This doesn't settle the issue of duplicate registrations. We'll need to
    # figure that out somehow.
    # TODO: After new control plane classes are in, promote the template to a FlyteTask,
    # so that authors of customized-container tasks have a familiar thing to work with.
    template_id = identifier_models.Identifier(
        identifier_models.ResourceType.TASK,
        settings.project,
        settings.domain,
        self.name,
        settings.version,
    )
    template = _task_model.TaskTemplate(
        template_id,
        self.task_type,
        self.metadata.to_taskmetadata_model(),
        self.interface,
        self.get_custom(settings),
        container=self.get_container(settings),
        config=self.get_config(settings),
    )
    # Cache the most recently built template on the instance before returning it.
    self._task_template = template
    return template
def test_task_template(in_tuple):
    """Round-trip a fully-populated TaskTemplate through IDL and verify every accessor.

    ``in_tuple`` is a parameterized (task_metadata, interfaces, resources) fixture.
    """
    task_metadata, interfaces, resources = in_tuple
    obj = task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task_metadata,
        interfaces,
        # Arbitrary nested "custom" payload to prove it survives serialization.
        {
            "a": 1,
            "b": {
                "c": 2,
                "d": 3
            }
        },
        container=task.Container(
            "my_image",
            ["this", "is", "a", "cmd"],
            ["this", "is", "an", "arg"],
            resources,
            {"a": "b"},
            {"d": "e"},
        ),
        config={"a": "b"},
    )
    assert obj.id.resource_type == identifier.ResourceType.TASK
    assert obj.id.project == "project"
    assert obj.id.domain == "domain"
    assert obj.id.name == "name"
    assert obj.id.version == "version"
    assert obj.type == "python"
    assert obj.metadata == task_metadata
    assert obj.interface == interfaces
    assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}}
    assert obj.container.image == "my_image"
    assert obj.container.resources == resources
    # Round-tripping through protobuf must be lossless (compare the text forms).
    assert text_format.MessageToString(
        obj.to_flyte_idl()) == text_format.MessageToString(
        task.TaskTemplate.from_flyte_idl(
            obj.to_flyte_idl()).to_flyte_idl())
    assert obj.config == {"a": "b"}
def test_task_template_security_context(sec_ctx):
    """A TaskTemplate must carry its security context through a protobuf round trip."""
    container = task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        parameterizers.LIST_OF_RESOURCES[0],
        {"a": "b"},
        {"d": "e"},
    )
    template = task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        parameterizers.LIST_OF_TASK_METADATA[0],
        parameterizers.LIST_OF_INTERFACES[0],
        {"a": 1, "b": {"c": 2, "d": 3}},
        container=container,
        security_context=sec_ctx,
    )
    assert template.security_context == sec_ctx
    # Serializing and re-parsing must produce an identical protobuf message.
    round_tripped = task.TaskTemplate.from_flyte_idl(template.to_flyte_idl())
    assert text_format.MessageToString(template.to_flyte_idl()) == text_format.MessageToString(
        round_tripped.to_flyte_idl()
    )
def get_serializable_launch_plan(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: LaunchPlan,
    fast: bool,
) -> _launch_plan_models.LaunchPlan:
    """Translate a LaunchPlan into its serializable admin model.

    The underlying workflow is serialized first so the launch plan spec can point at
    its identifier; optional spec fields fall back to empty model defaults.
    """
    wf_spec = get_serializable(entity_mapping, settings, entity.workflow)
    spec = _launch_plan_models.LaunchPlanSpec(
        workflow_id=wf_spec.template.id,
        entity_metadata=_launch_plan_models.LaunchPlanMetadata(
            schedule=entity.schedule,
            notifications=entity.notifications,
        ),
        default_inputs=entity.parameters,
        fixed_inputs=entity.fixed_inputs,
        labels=entity.labels or _common_models.Labels({}),
        annotations=entity.annotations or _common_models.Annotations({}),
        auth_role=entity._auth_role or _common_models.AuthRole(),
        raw_output_data_config=entity.raw_output_data_config or _common_models.RawOutputDataConfig(""),
    )
    # Launch plans start with an empty closure; Admin fills in state and expected IO.
    empty_closure = _launch_plan_models.LaunchPlanClosure(
        state=None,
        expected_inputs=interface_models.ParameterMap({}),
        expected_outputs=interface_models.VariableMap({}),
    )
    return _launch_plan_models.LaunchPlan(
        id=_identifier_model.Identifier(
            resource_type=_identifier_model.ResourceType.LAUNCH_PLAN,
            project=settings.project,
            domain=settings.domain,
            name=entity.name,
            version=settings.version,
        ),
        spec=spec,
        closure=empty_closure,
    )
def test_promote_from_model():
    """Promoting launch plan IDL yields an SdkLaunchPlan, never an SdkRunnableLaunchPlan."""
    workflow_to_test = _workflow.workflow(
        {},
        inputs={
            'required_input': _workflow.Input(_types.Types.Integer),
            'default_input': _workflow.Input(_types.Types.Integer, default=5),
        },
    )
    workflow_to_test._id = _identifier.Identifier(
        _identifier.ResourceType.WORKFLOW, "p", "d", "n", "v"
    )
    lp = workflow_to_test.create_launch_plan(
        fixed_inputs={'required_input': 5},
        schedule=_schedules.CronSchedule("* * ? * * *"),
        role='what',
        labels=_common_models.Labels({"my": "label"}),
    )

    serialized = lp.to_flyte_idl()
    # The runnable flavor cannot be reconstructed from serialized IDL alone.
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        _launch_plan.SdkRunnableLaunchPlan.from_flyte_idl(serialized)

    lp_from_spec = _launch_plan.SdkLaunchPlan.from_flyte_idl(serialized)
    assert not isinstance(lp_from_spec, _launch_plan.SdkRunnableLaunchPlan)
    assert isinstance(lp_from_spec, _launch_plan.SdkLaunchPlan)
    assert lp_from_spec == lp
def initialize(): """ Re-initializes the context and erases the entire context """ # This is supplied so that tasks that rely on Flyte provided param functionality do not fail when run locally default_execution_id = _identifier.WorkflowExecutionIdentifier( project="local", domain="local", name="local") cfg = Config.auto() # Ensure a local directory is available for users to work with. user_space_path = os.path.join(cfg.local_sandbox_path, "user_space") pathlib.Path(user_space_path).mkdir(parents=True, exist_ok=True) # Note we use the SdkWorkflowExecution object purely for formatting into the ex:project:domain:name format users # are already acquainted with default_context = FlyteContext( file_access=default_local_file_access_provider) default_user_space_params = ExecutionParameters( execution_id=WorkflowExecutionIdentifier.promote_from_model( default_execution_id), task_id=_identifier.Identifier(_identifier.ResourceType.TASK, "local", "local", "local", "local"), execution_date=_datetime.datetime.utcnow(), stats=mock_stats.MockStats(), logging=user_space_logger, tmp_dir=user_space_path, raw_output_prefix=default_context.file_access._raw_output_prefix, decks=[], ) default_context = default_context.with_execution_state( default_context.new_execution_state().with_params( user_space_params=default_user_space_params)).build() default_context.set_stackframe( s=FlyteContextManager.get_origin_stackframe()) flyte_context_Var.set([default_context])
from flytekit.models import task as _task_models
from flytekit.models import types as _type_models
from flytekit.models.core import identifier as _identifier
from flytekit.sdk.tasks import inputs, outputs, python_task
from flytekit.sdk.types import Types


# Minimal module-level task used as the fixture under test.
@inputs(in1=Types.Integer)
@outputs(out1=Types.String)
@python_task
def default_task(wf_params, in1, out1):
    pass


# Assign a concrete identifier so id-dependent behavior can be exercised.
default_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version")


def test_default_python_task():
    """The bare @python_task decorator should yield an SdkRunnableTask with typed, undocumented IO."""
    assert isinstance(default_task, _sdk_runnable.SdkRunnableTask)
    assert default_task.interface.inputs["in1"].description == ""
    assert default_task.interface.inputs[
        "in1"].type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
    assert default_task.interface.outputs["out1"].description == ""
    assert default_task.interface.outputs[
        "out1"].type == _type_models.LiteralType(simple=_type_models.SimpleType.STRING)
    assert default_task.type == _common_constants.SdkTaskType.PYTHON_TASK
    assert default_task.task_function_name == "default_task"
@inputs(in1=Types.Integer) @outputs(out1=Types.String) @sidecar_task( cpu_request='10', gpu_limit='2', environment={"foo": "bar"}, pod_spec=get_pod_spec(), primary_container_name="a container", ) def simple_sidecar_task(wf_params, in1, out1): pass simple_sidecar_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version") def test_sidecar_task(): assert isinstance(simple_sidecar_task, _sdk_task.SdkTask) assert isinstance(simple_sidecar_task, _sidecar_task.SdkSidecarTask) pod_spec = simple_sidecar_task.custom['podSpec'] assert pod_spec['restartPolicy'] == 'OnFailure' assert len(pod_spec['containers']) == 2 primary_container = pod_spec['containers'][0] assert primary_container['name'] == 'a container' assert primary_container['args'] == [ 'pyflyte-execute', '--task-module', 'tests.flytekit.unit.sdk.tasks.test_sidecar_tasks', '--task-name',
from datetime import timedelta

from flytekit.models import interface as _interface
from flytekit.models import literals as _literals
from flytekit.models import types as _types
from flytekit.models.core import condition as _condition
from flytekit.models.core import identifier as _identifier
from flytekit.models.core import workflow as _workflow

# Shared workflow identifier available to tests in this module.
_generic_id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version")


def test_node_metadata():
    """NodeMetadata should expose timeout/retries and survive an IDL round trip."""
    obj = _workflow.NodeMetadata(name="node1", timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0))
    assert obj.timeout.seconds == 10
    assert obj.retries.retries == 0
    obj2 = _workflow.NodeMetadata.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    assert obj2.timeout.seconds == 10
    assert obj2.retries.retries == 0


def test_alias():
    """Alias should keep var/alias fields and survive an IDL round trip."""
    obj = _workflow.Alias(var="myvar", alias="myalias")
    assert obj.alias == "myalias"
    assert obj.var == "myvar"
    obj2 = _workflow.Alias.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
from itertools import product import pytest from google.protobuf import text_format from flytekit.models import array_job as _array_job from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import literals as _literals from flytekit.models import task as _task from flytekit.models.core import identifier as _identifier from flytekit.models.core import workflow as _workflow from tests.flytekit.common import parameterizers LIST_OF_DYNAMIC_TASKS = [ _task.TaskTemplate( _identifier.Identifier(_identifier.ResourceType.TASK, "p", "d", "n", "v"), "python", task_metadata, interfaces, _array_job.ArrayJob(2, 2, 2).to_dict(), container=_task.Container( "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, ), ) for task_metadata, interfaces, resources in product( parameterizers.LIST_OF_TASK_METADATA, parameterizers.LIST_OF_INTERFACES,
def test_workflow():
    """Build a runnable workflow from nodes/inputs/outputs and verify its interface and
    bindings, both before and after an IDL round trip."""
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "my_task", "version")

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain", "my_list_task", "version")

    input_list = [
        promise.Input("input_1", primitives.Integer),
        promise.Input("input_2", primitives.Integer, default=5, help="Not required."),
    ]

    # Nodes exercise: promise inputs, a literal input, a chained output, a mixed list,
    # and an explicit ordering dependency (n1 >> n6).
    n1 = my_task(a=input_list[0]).assign_id_and_return("n1")
    n2 = my_task(a=input_list[1]).assign_id_and_return("n2")
    n3 = my_task(a=100).assign_id_and_return("n3")
    n4 = my_task(a=n1.outputs.b).assign_id_and_return("n4")
    n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100]).assign_id_and_return("n5")
    n6 = my_list_task(a=n5.outputs.b)
    n1 >> n6
    nodes = [n1, n2, n3, n4, n5, n6]

    w = _local_workflow.SdkRunnableWorkflow.construct_from_class_definition(
        inputs=input_list,
        outputs=[
            _local_workflow.Output("a", n1.outputs.b, sdk_type=primitives.Integer)
        ],
        nodes=nodes,
    )

    assert w.interface.inputs["input_1"].type == primitives.Integer.to_flyte_literal_type()
    assert w.interface.inputs["input_2"].type == primitives.Integer.to_flyte_literal_type()
    assert w.nodes[0].inputs[0].var == "a"
    assert w.nodes[0].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[0].inputs[0].binding.promise.var == "input_1"
    assert w.nodes[1].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[1].inputs[0].binding.promise.var == "input_2"
    assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
    assert w.nodes[3].inputs[0].var == "a"
    assert w.nodes[3].inputs[0].binding.promise.node_id == n1.id

    # Test conversion to flyte_idl and back
    w._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "fake", "faker", "fakest", "fakerest")
    w = _workflow_models.WorkflowTemplate.from_flyte_idl(w.to_flyte_idl())
    assert w.interface.inputs["input_1"].type == primitives.Integer.to_flyte_literal_type()
    assert w.interface.inputs["input_2"].type == primitives.Integer.to_flyte_literal_type()
    assert w.nodes[0].inputs[0].var == "a"
    assert w.nodes[0].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[0].inputs[0].binding.promise.var == "input_1"
    assert w.nodes[1].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[1].inputs[0].binding.promise.var == "input_2"
    assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
    assert w.nodes[3].inputs[0].var == "a"
    assert w.nodes[3].inputs[0].binding.promise.node_id == n1.id
    # n5's input is a collection binding mixing promises and a literal.
    assert w.nodes[4].inputs[0].var == "a"
    assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.var == "input_1"
    assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.var == "input_2"
    assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.node_id == n3.id
    assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.var == "b"
    assert w.nodes[4].inputs[0].binding.collection.bindings[3].scalar.primitive.integer == 100
    assert w.nodes[5].inputs[0].var == "a"
    assert w.nodes[5].inputs[0].binding.promise.node_id == n5.id
    assert w.nodes[5].inputs[0].binding.promise.var == "b"
    assert len(w.outputs) == 1
    assert w.outputs[0].var == "a"
    assert w.outputs[0].binding.promise.var == "b"
    assert w.outputs[0].binding.promise.node_id == "n1"
def __post_init__(self):
    # Cache an Identifier model assembled from this reference's own fields.
    ident_fields = (self.resource_type, self.project, self.domain, self.name, self.version)
    self._id = _identifier_model.Identifier(*ident_fields)
def test_builtin_algorithm_training_job_task():
    """A built-in XGBoost training job task should expose the expected interface,
    metadata defaults, and a parseable custom payload."""
    builtin_algorithm_training_job_task = SdkBuiltinAlgorithmTrainingJobTask(
        training_job_resource_config=TrainingJobResourceConfig(
            instance_type="ml.m4.xlarge",
            instance_count=1,
            volume_size_in_gb=25,
        ),
        algorithm_specification=AlgorithmSpecification(
            input_mode=InputMode.FILE,
            input_content_type=InputContentType.TEXT_CSV,
            algorithm_name=AlgorithmName.XGBOOST,
            algorithm_version="0.72",
        ),
    )
    builtin_algorithm_training_job_task._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name", "my_version")

    assert isinstance(builtin_algorithm_training_job_task, SdkBuiltinAlgorithmTrainingJobTask)
    assert isinstance(builtin_algorithm_training_job_task, _sdk_task.SdkTask)
    # The "train" input is a multipart CSV blob with no description.
    assert builtin_algorithm_training_job_task.interface.inputs["train"].description == ""
    assert builtin_algorithm_training_job_task.interface.inputs[
        "train"].type == _idl_types.LiteralType(blob=_core_types.BlobType(
            format="csv",
            dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
        ))
    assert (builtin_algorithm_training_job_task.interface.inputs["train"].type ==
            _sdk_types.Types.MultiPartCSV.to_flyte_literal_type())
    # Same shape for the "validation" input.
    assert builtin_algorithm_training_job_task.interface.inputs["validation"].description == ""
    assert (builtin_algorithm_training_job_task.interface.inputs["validation"].
            type == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type())
    assert builtin_algorithm_training_job_task.interface.inputs[
        "train"].type == _idl_types.LiteralType(blob=_core_types.BlobType(
            format="csv",
            dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
        ))
    assert builtin_algorithm_training_job_task.interface.inputs["static_hyperparameters"].description == ""
    assert (builtin_algorithm_training_job_task.interface.
            inputs["static_hyperparameters"].type == _sdk_types.Types.Generic.to_flyte_literal_type())
    # Single blob output: the trained model.
    assert builtin_algorithm_training_job_task.interface.outputs["model"].description == ""
    assert (builtin_algorithm_training_job_task.interface.outputs["model"].type ==
            _sdk_types.Types.Blob.to_flyte_literal_type())
    assert builtin_algorithm_training_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK
    # Metadata should all be at unset/default values.
    assert builtin_algorithm_training_job_task.metadata.timeout == _datetime.timedelta(seconds=0)
    assert builtin_algorithm_training_job_task.metadata.deprecated_error_message == ""
    assert builtin_algorithm_training_job_task.metadata.discoverable is False
    assert builtin_algorithm_training_job_task.metadata.discovery_version == ""
    assert builtin_algorithm_training_job_task.metadata.retries.retries == 0
    assert "metricDefinitions" not in builtin_algorithm_training_job_task.custom[
        "algorithmSpecification"].keys()
    ParseDict(
        builtin_algorithm_training_job_task.custom["trainingJobResourceConfig"],
        _pb2_TrainingJobResourceConfig(),
    )  # fails the test if it cannot be parsed
def test_workflow_closure():
    """A WorkflowClosure bundling a template with its task should round-trip via IDL."""
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    typed_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")},
        {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    # Bindings: one fixed scalar input, plus two outputs promised from 'my_node'.
    b0 = _literals.Binding(
        'a',
        _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))))
    b1 = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'b')))
    b2 = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'c')))

    node_metadata = _workflow.NodeMetadata(
        name='node1', timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0))

    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
        timedelta(days=1),
        _literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!")

    cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface,
        {
            'a': 1,
            'b': {
                'c': 2,
                'd': 3
            }
        },
        container=_task.Container("my_image", ["this", "is", "a", "cmd"],
                                  ["this", "is", "an", "arg"], resources, {}, {}))

    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(
        id='my_node',
        metadata=node_metadata,
        inputs=[b0],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node)

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    # Serializing and re-parsing the closure must reproduce an equal object.
    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
MetricDefinition(name="Validation error", regex="validation:error") ], ), ) simple_xgboost_hpo_job_task = hpo_job_task.SdkSimpleHyperparameterTuningJobTask( training_job=builtin_algorithm_training_job_task2, max_number_of_training_jobs=10, max_parallel_training_jobs=5, cache_version="1", retries=2, cacheable=True, ) simple_xgboost_hpo_job_task._id = _identifier.Identifier( _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name", "my_version") def test_simple_hpo_job_task(): assert isinstance(simple_xgboost_hpo_job_task, SdkSimpleHyperparameterTuningJobTask) assert isinstance(simple_xgboost_hpo_job_task, _sdk_task.SdkTask) # Checking if the input of the underlying SdkTrainingJobTask has been embedded assert simple_xgboost_hpo_job_task.interface.inputs[ "train"].description == "" assert (simple_xgboost_hpo_job_task.interface.inputs["train"].type == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type()) assert simple_xgboost_hpo_job_task.interface.inputs[ "train"].type == _idl_types.LiteralType(blob=_core_types.BlobType( format="csv",
def test_workflow_decorator():
    """Build an SDK workflow from a metaclass-style definition and check its
    structure, both directly and after a round trip through flyte IDL."""

    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                         "propject", "domain", "my_task",
                                         "version")

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                              "propject", "domain",
                                              "my_list_task", "version")

    # Attribute names matter here: node ids in the built workflow are derived
    # from them.
    class my_workflow(object):
        input_1 = promise.Input("input_1", primitives.Integer)
        input_2 = promise.Input("input_2",
                                primitives.Integer,
                                default=5,
                                help="Not required.")
        n1 = my_task(a=input_1)
        n2 = my_task(a=input_2)
        n3 = my_task(a=100)
        n4 = my_task(a=n1.outputs.b)
        n5 = my_list_task(a=[input_1, input_2, n3.outputs.b, 100])
        n6 = my_list_task(a=n5.outputs.b)
        n1 >> n6
        a = _local_workflow.Output("a",
                                   n1.outputs.b,
                                   sdk_type=primitives.Integer)

    wf = _local_workflow.build_sdk_workflow_from_metaclass(
        my_workflow,
        on_failure=_workflow_models.WorkflowMetadata.OnFailurePolicy.
        FAIL_AFTER_EXECUTABLE_NODES_COMPLETE,
    )

    int_literal_type = primitives.Integer.to_flyte_literal_type()
    global_node = constants.GLOBAL_INPUT_NODE_ID

    assert wf.should_create_default_launch_plan is True

    # Inputs are typed as integers.
    assert wf.interface.inputs["input_1"].type == int_literal_type
    assert wf.interface.inputs["input_2"].type == int_literal_type

    # n1 and n2 consume the global workflow inputs.
    assert wf.nodes[0].inputs[0].var == "a"
    assert wf.nodes[0].inputs[0].binding.promise.node_id == global_node
    assert wf.nodes[0].inputs[0].binding.promise.var == "input_1"
    assert wf.nodes[1].inputs[0].binding.promise.node_id == global_node
    assert wf.nodes[1].inputs[0].binding.promise.var == "input_2"
    # n3 takes a literal; n4 chains off n1.
    assert wf.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
    assert wf.nodes[3].inputs[0].var == "a"
    assert wf.nodes[3].inputs[0].binding.promise.node_id == "n1"

    # Round-trip the template through flyte IDL and re-check everything.
    wf.id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "fake",
                                   "faker", "fakest", "fakerest")
    wf = _workflow_models.WorkflowTemplate.from_flyte_idl(wf.to_flyte_idl())

    assert wf.interface.inputs["input_1"].type == int_literal_type
    assert wf.interface.inputs["input_2"].type == int_literal_type
    assert wf.nodes[0].inputs[0].var == "a"
    assert wf.nodes[0].inputs[0].binding.promise.node_id == global_node
    assert wf.nodes[0].inputs[0].binding.promise.var == "input_1"
    assert wf.nodes[1].inputs[0].binding.promise.node_id == global_node
    assert wf.nodes[1].inputs[0].binding.promise.var == "input_2"
    assert wf.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
    assert wf.nodes[3].inputs[0].var == "a"
    assert wf.nodes[3].inputs[0].binding.promise.node_id == "n1"

    # n5's list input is a binding collection mixing promises and a literal.
    assert wf.nodes[4].inputs[0].var == "a"
    collection = wf.nodes[4].inputs[0].binding.collection.bindings
    assert collection[0].promise.node_id == global_node
    assert collection[0].promise.var == "input_1"
    assert collection[1].promise.node_id == global_node
    assert collection[1].promise.var == "input_2"
    assert collection[2].promise.node_id == "n3"
    assert collection[2].promise.var == "b"
    assert collection[3].scalar.primitive.integer == 100

    # n6 consumes n5's list output.
    assert wf.nodes[5].inputs[0].var == "a"
    assert wf.nodes[5].inputs[0].binding.promise.node_id == "n5"
    assert wf.nodes[5].inputs[0].binding.promise.var == "b"

    # Single declared output 'a', promised from n1.b.
    assert len(wf.outputs) == 1
    assert wf.outputs[0].var == "a"
    assert wf.outputs[0].binding.promise.var == "b"
    assert wf.outputs[0].binding.promise.node_id == "n1"

    assert (wf.metadata.on_failure == _workflow_models.WorkflowMetadata.
            OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE)
def test_execution_spec(literal_value_pair):
    """ExecutionSpec round-trips through flyte IDL, both with a notification
    list and with ``disable_all=True``."""
    # Fixture is unpacked for shape validation; the value itself is unused.
    literal_value, _ = literal_value_pair

    def check_common(spec):
        # Fields shared by every spec constructed in this test.
        assert spec.launch_plan.resource_type == \
            _identifier.ResourceType.LAUNCH_PLAN
        assert spec.launch_plan.domain == "domain"
        assert spec.launch_plan.project == "project"
        assert spec.launch_plan.name == "name"
        assert spec.launch_plan.version == "version"
        assert spec.metadata.mode == \
            _execution.ExecutionMetadata.ExecutionMode.MANUAL
        assert spec.metadata.nesting == 1
        assert spec.metadata.principal == "tester"

    def check_notifications(spec):
        notification = spec.notifications.notifications[0]
        assert notification.phases == [
            _core_exec.WorkflowExecutionPhase.ABORTED
        ]
        assert notification.pager_duty.recipients_email == [
            "a",
            "b",
            "c",
        ]

    launch_plan_id = _identifier.Identifier(
        _identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name",
        "version")

    # Case 1: spec carrying a notification list; disable_all left unset.
    obj = _execution.ExecutionSpec(
        launch_plan_id,
        _execution.ExecutionMetadata(
            _execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1),
        notifications=_execution.NotificationList([
            _common_models.Notification(
                [_core_exec.WorkflowExecutionPhase.ABORTED],
                pager_duty=_common_models.PagerDutyNotification(
                    recipients_email=["a", "b", "c"]),
            )
        ]),
    )
    check_common(obj)
    check_notifications(obj)
    assert obj.disable_all is None

    obj2 = _execution.ExecutionSpec.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    check_common(obj2)
    check_notifications(obj2)
    assert obj2.disable_all is None

    # Case 2: spec with disable_all=True and no notifications.
    obj = _execution.ExecutionSpec(
        launch_plan_id,
        _execution.ExecutionMetadata(
            _execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1),
        disable_all=True,
    )
    check_common(obj)
    assert obj.notifications is None
    assert obj.disable_all is True

    obj2 = _execution.ExecutionSpec.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
    check_common(obj2)
    assert obj2.notifications is None
    assert obj2.disable_all is True
def test_workflow_node():
    """Exercise calling a runnable workflow as a sub-workflow node: input
    validation, default handling, and promised outputs."""

    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                         "project", "domain", "my_task",
                                         "version")

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                              "project", "domain",
                                              "my_list_task", "version")

    wf_inputs = [
        promise.Input("required", primitives.Integer),
        promise.Input("not_required",
                      primitives.Integer,
                      default=5,
                      help="Not required."),
    ]

    n1 = my_task(a=wf_inputs[0]).assign_id_and_return("n1")
    n2 = my_task(a=wf_inputs[1]).assign_id_and_return("n2")
    n3 = my_task(a=100).assign_id_and_return("n3")
    n4 = my_task(a=n1.outputs.b).assign_id_and_return("n4")
    n5 = my_list_task(
        a=[wf_inputs[0], wf_inputs[1], n3.outputs.b, 100]
    ).assign_id_and_return("n5")
    n6 = my_list_task(a=n5.outputs.b)  # id assigned by the workflow builder

    wf_outputs = [
        _local_workflow.Output(
            "nested_out",
            [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]],
            sdk_type=[[primitives.Integer]],
        ),
        _local_workflow.Output("scalar_out",
                               n1.outputs.b,
                               sdk_type=primitives.Integer),
    ]

    wf = _local_workflow.SdkRunnableWorkflow.construct_from_class_definition(
        inputs=wf_inputs, outputs=wf_outputs, nodes=[n1, n2, n3, n4, n5, n6])

    # Missing required input is rejected.
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        wf()
    # Positional arguments are rejected.
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        wf(1, 2)
    # Type mismatches are rejected.
    with _pytest.raises(_user_exceptions.FlyteTypeException):
        wf(required="abc", not_required=1)
    # Unknown argument names are rejected.
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        wf(required=1, bad_arg=1)

    # Default for 'not_required' is filled in when omitted.
    node = wf(required=10)
    assert node.inputs[0].var == "not_required"
    assert node.inputs[0].binding.scalar.primitive.integer == 5
    assert node.inputs[1].var == "required"
    assert node.inputs[1].binding.scalar.primitive.integer == 10

    # Explicit value overrides the default.
    node = wf(required=10, not_required=50)
    assert node.inputs[0].var == "not_required"
    assert node.inputs[0].binding.scalar.primitive.integer == 50
    assert node.inputs[1].var == "required"
    assert node.inputs[1].binding.scalar.primitive.integer == 10

    # The node keeps a live reference to the workflow's id.
    wf.id = "fake"
    assert node.workflow_node.sub_workflow_ref == "fake"
    wf.id = None

    # Outputs are promised under a dns'ified node id.
    node.assign_id_and_return("node-id*")
    scalar_out = node.outputs["scalar_out"]
    assert scalar_out.sdk_type.to_flyte_literal_type() == \
        primitives.Integer.to_flyte_literal_type()
    assert scalar_out.var == "scalar_out"
    assert scalar_out.node_id == "node-id"

    nested_out = node.outputs["nested_out"]
    assert (nested_out.sdk_type.to_flyte_literal_type() == containers.List(
        containers.List(primitives.Integer)).to_flyte_literal_type())
    assert nested_out.var == "nested_out"
    assert nested_out.node_id == "node-id"