Ejemplo n.º 1
0
def test_execution_annotation_overrides(mock_client_factory):
    """Launching with annotation_overrides must forward them into the ExecutionSpec."""
    client = MagicMock()
    client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn"))
    mock_client_factory.return_value = client

    mock_lp = MagicMock()
    type(mock_lp).id = PropertyMock(
        return_value=identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version")
    )

    annotation_overrides = _common_models.Annotations({"my": "annotation"})
    engine.FlyteLaunchPlan(mock_lp).launch(
        "xp",
        "xd",
        "xn",
        literals.LiteralMap({}),
        notification_overrides=[],
        annotation_overrides=annotation_overrides,
    )

    # The spec handed to Admin must carry the overridden annotations.
    expected_spec = _execution_models.ExecutionSpec(
        identifier.Identifier(
            identifier.ResourceType.LAUNCH_PLAN,
            "project",
            "domain",
            "name",
            "version",
        ),
        _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0),
        disable_all=True,
        annotations=annotation_overrides,
    )
    client.create_execution.assert_called_once_with("xp", "xd", "xn", expected_spec, literals.LiteralMap({}))
Ejemplo n.º 2
0
def test_basic_workflow_promote(mock_task_fetch):
    """Promote a workflow template (as stored in Admin) back into an SdkWorkflow
    and verify its interface and node graph match the locally-defined class.
    """
    # This section defines a sample workflow from a user
    @_sdk_tasks.inputs(a=_Types.Integer)
    @_sdk_tasks.outputs(b=_Types.Integer, c=_Types.Integer)
    @_sdk_tasks.python_task()
    def demo_task_for_promote(wf_params, a, b, c):
        b.set(a + 1)
        c.set(a + 2)

    @_sdk_workflow.workflow_class()
    class TestPromoteExampleWf(object):
        wf_input = _sdk_workflow.Input(_Types.Integer, required=True)
        my_task_node = demo_task_for_promote(a=wf_input)
        wf_output_b = _sdk_workflow.Output(my_task_node.outputs.b, sdk_type=_Types.Integer)
        wf_output_c = _sdk_workflow.Output(my_task_node.outputs.c, sdk_type=_Types.Integer)

    # This section uses the TaskTemplate stored in Admin to promote back to an Sdk Workflow
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    task_interface = _interface.TypedInterface(
        # inputs
        {'a': _interface.Variable(int_type, "description1")},
        # outputs
        {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        }
    )
    # Since the promotion of a workflow requires retrieving the task from Admin, we mock the SdkTask to return
    task_template = _task_model.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain",
                               "tests.flytekit.unit.common_tests.test_workflow_promote.demo_task_for_promote",
                               "version"),
        "python_container",
        get_sample_task_metadata(),
        task_interface,
        custom={},
        container=get_sample_container()
    )
    sdk_promoted_task = _task.SdkTask.promote_from_model(task_template)
    mock_task_fetch.return_value = sdk_promoted_task
    workflow_template = get_workflow_template()
    promoted_wf = _workflow_common.SdkWorkflow.promote_from_model(workflow_template)

    # The promoted workflow's interface must match the user-defined class's interface.
    assert promoted_wf.interface.inputs["wf_input"] == TestPromoteExampleWf.interface.inputs["wf_input"]
    assert promoted_wf.interface.outputs["wf_output_b"] == TestPromoteExampleWf.interface.outputs["wf_output_b"]
    assert promoted_wf.interface.outputs["wf_output_c"] == TestPromoteExampleWf.interface.outputs["wf_output_c"]

    # Both graphs contain exactly one node, with the same first input binding.
    assert len(promoted_wf.nodes) == 1
    assert len(TestPromoteExampleWf.nodes) == 1
    assert promoted_wf.nodes[0].inputs[0] == TestPromoteExampleWf.nodes[0].inputs[0]
Ejemplo n.º 3
0
def activate_all_impl(project, domain, version, pkgs, ignore_schedules=False):
    """Activate every launch plan discoverable in the given packages.

    Each launch plan's id is rewritten to the target project/domain/version
    before activation. Scheduled launch plans are skipped when
    ignore_schedules is True.
    """
    # TODO: This should be a transaction to ensure all or none are updated
    # TODO: We should optionally allow deactivation of missing launch plans

    # Discover all launch plans by loading the modules
    entities = iterate_registerable_entities_in_order(
        pkgs,
        include_entities={_SdkLaunchPlan},
        detect_unreferenced_entities=False)
    for module, key, launch_plan in entities:
        launch_plan._id = _identifier.Identifier(
            _identifier.ResourceType.LAUNCH_PLAN,
            project,
            domain,
            _utils.fqdn(module.__name__, key, entity_type=launch_plan.resource_type),
            version,
        )
        # Only a scheduled plan combined with ignore_schedules is left untouched.
        if not (launch_plan.is_scheduled and ignore_schedules):
            launch_plan.update(_launch_plan_model.LaunchPlanState.ACTIVE)
Ejemplo n.º 4
0
def test_task_execution_identifier():
    """TaskExecutionIdentifier keeps its fields across a flyte IDL round trip."""
    task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version")
    wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")
    node_exec_id = identifier.NodeExecutionIdentifier("node_id", wf_exec_id)
    original = identifier.TaskExecutionIdentifier(task_id, node_exec_id, 3)
    round_tripped = identifier.TaskExecutionIdentifier.from_flyte_idl(original.to_flyte_idl())

    assert round_tripped == original
    for obj in (original, round_tripped):
        assert obj.retry_attempt == 3
        assert obj.task_id == task_id
        assert obj.node_execution_id == node_exec_id
Ejemplo n.º 5
0
def test_identifier():
    """Identifier exposes its fields and survives a flyte IDL round trip."""
    original = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version")
    round_tripped = identifier.Identifier.from_flyte_idl(original.to_flyte_idl())

    assert round_tripped == original
    for obj in (original, round_tripped):
        assert obj.resource_type == identifier.ResourceType.TASK
        assert obj.project == "project"
        assert obj.domain == "domain"
        assert obj.name == "name"
        assert obj.version == "version"
Ejemplo n.º 6
0
def test_flyte_workflow_integration(mock_url, mock_client_manager):
    """FlyteWorkflow.fetch should return the workflow the admin client lists.

    Fix: the mocked ``list_workflows_paginated`` previously used ``returnValue=``,
    which MagicMock treats as an arbitrary attribute rather than the configured
    return value, so the mock never actually returned the workflow page.
    It must be ``return_value=`` (as the other mocks in this file use).
    """
    mock_url.get.return_value = "localhost"
    admin_workflow = _workflow_models.Workflow(
        _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p1", "d1",
                               "n1", "v1"),
        _MagicMock(),
    )
    mock_client = _MagicMock()
    # return_value (not returnValue) so the call yields ([admin_workflow], "").
    mock_client.list_workflows_paginated = _MagicMock(
        return_value=([admin_workflow], ""))
    mock_client_manager.return_value.client = mock_client

    workflow = _workflow.FlyteWorkflow.fetch("p1", "d1", "n1", "v1")
    assert workflow.entity_type_text == "Workflow"
    assert workflow.id == admin_workflow.id
Ejemplo n.º 7
0
def test_flyte_task_fetch(mock_url, mock_client_manager):
    """fetch/fetch_latest should resolve tasks by version from the admin listing."""
    mock_url.get.return_value = "localhost"

    def make_admin_task(version):
        # Build an admin Task model for the given version.
        return _task_models.Task(
            _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", version),
            _MagicMock(),
        )

    admin_task_v1 = make_admin_task("v1")
    admin_task_v2 = make_admin_task("v2")
    mock_client = _MagicMock()
    # v2 is listed first; the assertions below expect it to be the "latest".
    mock_client.list_tasks_paginated = _MagicMock(return_value=([admin_task_v2, admin_task_v1], ""))
    mock_client_manager.return_value.client = mock_client

    latest_task = _task.FlyteTask.fetch_latest("p1", "d1", "n1")
    task_v1 = _task.FlyteTask.fetch("p1", "d1", "n1", "v1")
    task_v2 = _task.FlyteTask.fetch("p1", "d1", "n1", "v2")
    assert task_v1.id == admin_task_v1.id
    assert task_v1.id != latest_task.id
    assert task_v2.id == latest_task.id == admin_task_v2.id

    for fetched in (task_v1, task_v2):
        assert fetched.entity_type_text == "Task"
        assert fetched.resource_type == _identifier.ResourceType.TASK
Ejemplo n.º 8
0
def test__extract_files(load_mock):
    """_extract_pair should return the identifier IDL and the serialized task spec."""
    task_id = _core_identifier.Identifier(
        _core_identifier.ResourceType.TASK, 'myproject', 'development', 'name', 'v')
    sample_task = get_sample_task()
    overrides = {
        'image': 'myflyteimage:v123',
        'project': 'myflyteproject',
        'domain': 'development',
    }
    # Serialize under a controlled configuration so the spec is deterministic.
    with TemporaryConfiguration("", internal_overrides=overrides):
        task_spec = sample_task.serialize()

    # _extract_pair loads two files; feed it the id IDL then the spec.
    load_mock.side_effect = [task_id.to_flyte_idl(), task_spec]
    new_id, entity = _main._extract_pair('a', 'b')
    assert new_id == task_id.to_flyte_idl()
    assert entity == task_spec
Ejemplo n.º 9
0
def test__extract_files(load_mock):
    """_extract_pair returns the (identifier IDL, serialized spec) pair it loads."""
    expected_id = _core_identifier.Identifier(
        _core_identifier.ResourceType.TASK, "myproject", "development", "name", "v")
    sample = get_sample_task()
    # Serialize under a controlled configuration so the spec is deterministic.
    with TemporaryConfiguration(
        "",
        internal_overrides={
            "image": "myflyteimage:v123",
            "project": "myflyteproject",
            "domain": "development",
        },
    ):
        task_spec = sample.serialize()

    # _extract_pair loads two files; feed it the id IDL then the spec.
    load_mock.side_effect = [expected_id.to_flyte_idl(), task_spec]
    new_id, entity = _main._extract_pair("a", "b")
    assert new_id == expected_id.to_flyte_idl()
    assert entity == task_spec
Ejemplo n.º 10
0
def get_serializable_task(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
    fast: bool,
) -> task_models.TaskSpec:
    """Build a TaskSpec model for the given local task entity.

    When ``fast`` is set and the task runs user code in an auto container,
    the container command is prefixed with the ``pyflyte-fast-execute``
    wrapper so the user's distribution can be downloaded at run time.
    """
    task_id = _identifier_model.Identifier(
        _identifier_model.ResourceType.TASK,
        settings.project,
        settings.domain,
        entity.name,
        settings.version,
    )
    template = task_models.TaskTemplate(
        id=task_id,
        type=entity.task_type,
        metadata=entity.metadata.to_taskmetadata_model(),
        interface=entity.interface,
        custom=entity.get_custom(settings),
        container=entity.get_container(settings),
        task_type_version=entity.task_type_version,
        security_context=entity.security_context,
        config=entity.get_config(settings),
        k8s_pod=entity.get_k8s_pod(settings),
    )

    # For fast registration, we'll need to muck with the command, but only for certain kinds of tasks. Specifically,
    # tasks that rely on user code defined in the container. This should be encapsulated by the auto container
    # parent class
    if fast and isinstance(entity, PythonAutoContainerTask):
        wrapped_args = [
            "pyflyte-fast-execute",
            "--additional-distribution",
            "{{ .remote_package_path }}",
            "--dest-dir",
            "{{ .dest_dir }}",
            "--",
            *template.container.args,
        ]
        # Replace the args in place so the template's container keeps the same list object.
        template.container.args[:] = wrapped_args

    return task_models.TaskSpec(template=template)
Ejemplo n.º 11
0
def test_task_template__k8s_pod_target():
    """A TaskTemplate targeting a K8sPod keeps all fields and round-trips through IDL."""
    int_type = types.LiteralType(types.SimpleType.INTEGER)
    metadata = task.TaskMetadata(
        False,
        task.RuntimeMetadata(1, "v", "f"),
        timedelta(days=1),
        literal_models.RetryStrategy(5),
        False,
        "1.0",
        "deprecated",
    )
    typed_interface = interface_models.TypedInterface(
        # inputs
        {"a": interface_models.Variable(int_type, "description1")},
        # outputs
        {
            "b": interface_models.Variable(int_type, "description2"),
            "c": interface_models.Variable(int_type, "description3"),
        },
    )
    pod_metadata = task.K8sObjectMetadata(labels={"label": "foo"}, annotations={"anno": "bar"})
    pod = task.K8sPod(metadata=pod_metadata, pod_spec={"str": "val", "int": 1})
    custom_payload = {"a": 1, "b": {"c": 2, "d": 3}}

    obj = task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        metadata,
        typed_interface,
        custom_payload,
        config={"a": "b"},
        k8s_pod=pod,
    )

    assert obj.id.resource_type == identifier.ResourceType.TASK
    assert obj.id.project == "project"
    assert obj.id.domain == "domain"
    assert obj.id.name == "name"
    assert obj.id.version == "version"
    assert obj.type == "python"
    assert obj.custom == custom_payload
    assert obj.k8s_pod.metadata == pod_metadata
    assert obj.k8s_pod.pod_spec == {"str": "val", "int": 1}
    assert obj.config == {"a": "b"}
    # Serializing to IDL and back must be lossless.
    round_tripped = task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl())
    assert text_format.MessageToString(obj.to_flyte_idl()) == text_format.MessageToString(
        round_tripped.to_flyte_idl()
    )
Ejemplo n.º 12
0
def test_old_style_role():
    """A LaunchPlanSpec IDL carrying the deprecated Auth message must surface it via auth_role."""
    workflow_id = identifier.Identifier(
        identifier.ResourceType.TASK, "project", "domain", "name", "version")

    metadata = launch_plan.LaunchPlanMetadata(
        schedule=schedule.Schedule("asdf", "1 3 4 5 6 7"), notifications=[])

    bool_var = interface.Variable(
        types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")
    default_inputs = interface.ParameterMap({"ppp": interface.Parameter(var=bool_var)})

    one = literals.Literal(
        scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))
    fixed_inputs = literals.LiteralMap({"a": one})

    old_style_spec = _launch_plan_idl.LaunchPlanSpec(
        workflow_id=workflow_id.to_flyte_idl(),
        entity_metadata=metadata.to_flyte_idl(),
        default_inputs=default_inputs.to_flyte_idl(),
        fixed_inputs=fixed_inputs.to_flyte_idl(),
        labels=common.Labels({}).to_flyte_idl(),
        annotations=common.Annotations({"my": "annotation"}).to_flyte_idl(),
        raw_output_data_config=common.RawOutputDataConfig("s3://bucket").to_flyte_idl(),
        # Deprecated auth message: only the kubernetes service account is set.
        auth=_launch_plan_idl.Auth(kubernetes_service_account="my:service:account"),
    )

    lp_spec = launch_plan.LaunchPlanSpec.from_flyte_idl(old_style_spec)

    assert lp_spec.auth_role.assumable_iam_role == "my:service:account"
 def serialize_to_model(
         self, settings: SerializationSettings) -> _task_model.TaskTemplate:
     """Serialize this task into a TaskTemplate model, cache it on the
     instance (``self._task_template``), and return it.
     """
     # This doesn't get called from translator unfortunately. Will need to move the translator to use the model
     # objects directly first.
     # Note: This doesn't settle the issue of duplicate registrations. We'll need to figure that out somehow.
     # TODO: After new control plane classes are in, promote the template to a FlyteTask, so that authors of
     #  customized-container tasks have a familiar thing to work with.
     obj = _task_model.TaskTemplate(
         identifier_models.Identifier(identifier_models.ResourceType.TASK,
                                      settings.project, settings.domain,
                                      self.name, settings.version),
         self.task_type,
         self.metadata.to_taskmetadata_model(),
         self.interface,
         self.get_custom(settings),
         container=self.get_container(settings),
         config=self.get_config(settings),
     )
     # Cache so later lookups reuse the same template object.
     self._task_template = obj
     return obj
Ejemplo n.º 14
0
def test_task_template(in_tuple):
    """Exercise TaskTemplate construction and its flyte IDL round trip."""
    task_metadata, interfaces, resources = in_tuple
    custom_payload = {"a": 1, "b": {"c": 2, "d": 3}}
    container = task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        resources,
        {"a": "b"},
        {"d": "e"},
    )
    obj = task.TaskTemplate(
        identifier.Identifier(
            identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task_metadata,
        interfaces,
        custom_payload,
        container=container,
        config={"a": "b"},
    )

    template_id = obj.id
    assert template_id.resource_type == identifier.ResourceType.TASK
    assert template_id.project == "project"
    assert template_id.domain == "domain"
    assert template_id.name == "name"
    assert template_id.version == "version"
    assert obj.type == "python"
    assert obj.metadata == task_metadata
    assert obj.interface == interfaces
    assert obj.custom == custom_payload
    assert obj.container.image == "my_image"
    assert obj.container.resources == resources
    assert obj.config == {"a": "b"}
    # The IDL round trip must be lossless.
    round_tripped = task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl())
    assert text_format.MessageToString(obj.to_flyte_idl()) == text_format.MessageToString(
        round_tripped.to_flyte_idl())
Ejemplo n.º 15
0
def test_task_template_security_context(sec_ctx):
    """The security_context kwarg must be stored and survive the IDL round trip."""
    container = task.Container(
        "my_image",
        ["this", "is", "a", "cmd"],
        ["this", "is", "an", "arg"],
        parameterizers.LIST_OF_RESOURCES[0],
        {"a": "b"},
        {"d": "e"},
    )
    obj = task.TaskTemplate(
        identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        parameterizers.LIST_OF_TASK_METADATA[0],
        parameterizers.LIST_OF_INTERFACES[0],
        {"a": 1, "b": {"c": 2, "d": 3}},
        container=container,
        security_context=sec_ctx,
    )
    assert obj.security_context == sec_ctx
    round_tripped = task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl())
    assert text_format.MessageToString(obj.to_flyte_idl()) == text_format.MessageToString(
        round_tripped.to_flyte_idl()
    )
Ejemplo n.º 16
0
def get_serializable_launch_plan(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: LaunchPlan,
    fast: bool,
) -> _launch_plan_models.LaunchPlan:
    """Convert a local LaunchPlan entity into its admin LaunchPlan model.

    The referenced workflow is serialized first so its id can be embedded in
    the launch plan spec. Optional spec fields fall back to empty model
    defaults when the entity does not provide them.
    """
    wf_spec = get_serializable(entity_mapping, settings, entity.workflow)

    spec = _launch_plan_models.LaunchPlanSpec(
        workflow_id=wf_spec.template.id,
        entity_metadata=_launch_plan_models.LaunchPlanMetadata(
            schedule=entity.schedule, notifications=entity.notifications),
        default_inputs=entity.parameters,
        fixed_inputs=entity.fixed_inputs,
        labels=entity.labels or _common_models.Labels({}),
        annotations=entity.annotations or _common_models.Annotations({}),
        auth_role=entity._auth_role or _common_models.AuthRole(),
        raw_output_data_config=(
            entity.raw_output_data_config or _common_models.RawOutputDataConfig("")
        ),
    )
    launch_plan_id = _identifier_model.Identifier(
        resource_type=_identifier_model.ResourceType.LAUNCH_PLAN,
        project=settings.project,
        domain=settings.domain,
        name=entity.name,
        version=settings.version,
    )
    # The closure fields are intentionally left empty at serialization time.
    closure = _launch_plan_models.LaunchPlanClosure(
        state=None,
        expected_inputs=interface_models.ParameterMap({}),
        expected_outputs=interface_models.VariableMap({}),
    )
    return _launch_plan_models.LaunchPlan(id=launch_plan_id, spec=spec, closure=closure)
Ejemplo n.º 17
0
def test_promote_from_model():
    """SdkLaunchPlan can be rebuilt from IDL, but SdkRunnableLaunchPlan cannot."""
    workflow_inputs = {
        'required_input': _workflow.Input(_types.Types.Integer),
        'default_input': _workflow.Input(_types.Types.Integer, default=5),
    }
    workflow_to_test = _workflow.workflow({}, inputs=workflow_inputs)
    workflow_to_test._id = _identifier.Identifier(
        _identifier.ResourceType.WORKFLOW, "p", "d", "n", "v")
    lp = workflow_to_test.create_launch_plan(
        fixed_inputs={'required_input': 5},
        schedule=_schedules.CronSchedule("* * ? * * *"),
        role='what',
        labels=_common_models.Labels({"my": "label"}),
    )

    # The runnable variant refuses reconstruction from serialized form.
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        _launch_plan.SdkRunnableLaunchPlan.from_flyte_idl(lp.to_flyte_idl())

    lp_from_spec = _launch_plan.SdkLaunchPlan.from_flyte_idl(lp.to_flyte_idl())
    assert isinstance(lp_from_spec, _launch_plan.SdkLaunchPlan)
    assert not isinstance(lp_from_spec, _launch_plan.SdkRunnableLaunchPlan)
    assert lp_from_spec == lp
Ejemplo n.º 18
0
    def initialize():
        """
        Re-initializes the context and erases the entire context
        """
        # This is supplied so that tasks that rely on Flyte provided param functionality do not fail when run locally
        default_execution_id = _identifier.WorkflowExecutionIdentifier(
            project="local", domain="local", name="local")

        cfg = Config.auto()
        # Ensure a local directory is available for users to work with.
        user_space_path = os.path.join(cfg.local_sandbox_path, "user_space")
        pathlib.Path(user_space_path).mkdir(parents=True, exist_ok=True)

        # Note we use the SdkWorkflowExecution object purely for formatting into the ex:project:domain:name format users
        # are already acquainted with
        default_context = FlyteContext(
            file_access=default_local_file_access_provider)
        # Placeholder execution parameters: every identifier field is "local"
        # so locally-run tasks still see a complete parameter set.
        default_user_space_params = ExecutionParameters(
            execution_id=WorkflowExecutionIdentifier.promote_from_model(
                default_execution_id),
            task_id=_identifier.Identifier(_identifier.ResourceType.TASK,
                                           "local", "local", "local", "local"),
            execution_date=_datetime.datetime.utcnow(),
            stats=mock_stats.MockStats(),
            logging=user_space_logger,
            tmp_dir=user_space_path,
            raw_output_prefix=default_context.file_access._raw_output_prefix,
            decks=[],
        )

        # Attach the user-space params to a fresh execution state, then make
        # this the only context on the stack.
        default_context = default_context.with_execution_state(
            default_context.new_execution_state().with_params(
                user_space_params=default_user_space_params)).build()
        default_context.set_stackframe(
            s=FlyteContextManager.get_origin_stackframe())
        flyte_context_Var.set([default_context])
Ejemplo n.º 19
0
from flytekit.models import task as _task_models
from flytekit.models import types as _type_models
from flytekit.models.core import identifier as _identifier
from flytekit.sdk.tasks import inputs, outputs, python_task
from flytekit.sdk.types import Types


# Minimal python task fixture: one Integer input, one String output, empty body.
@inputs(in1=Types.Integer)
@outputs(out1=Types.String)
@python_task
def default_task(wf_params, in1, out1):
    pass


# Assign a concrete Identifier to the fixture so id-dependent code paths work.
default_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                          "project", "domain", "name",
                                          "version")


def test_default_python_task():
    """Check the interface, type, and function-name metadata of default_task."""
    int_literal = _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
    str_literal = _type_models.LiteralType(simple=_type_models.SimpleType.STRING)

    assert isinstance(default_task, _sdk_runnable.SdkRunnableTask)
    in_var = default_task.interface.inputs["in1"]
    out_var = default_task.interface.outputs["out1"]
    assert in_var.description == ""
    assert in_var.type == int_literal
    assert out_var.description == ""
    assert out_var.type == str_literal
    assert default_task.type == _common_constants.SdkTaskType.PYTHON_TASK
    assert default_task.task_function_name == "default_task"
Ejemplo n.º 20
0
    # Sidecar task fixture: declares resource requests/limits plus a
    # user-supplied pod spec and primary container name.
    @inputs(in1=Types.Integer)
    @outputs(out1=Types.String)
    @sidecar_task(
        cpu_request='10',
        gpu_limit='2',
        environment={"foo": "bar"},
        pod_spec=get_pod_spec(),
        primary_container_name="a container",
    )
    def simple_sidecar_task(wf_params, in1, out1):
        # Body intentionally empty: only the generated pod spec is under test.
        pass


# Assign a concrete Identifier to the fixture so id-dependent code paths work.
simple_sidecar_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                                 "project", "domain", "name",
                                                 "version")


def test_sidecar_task():
    assert isinstance(simple_sidecar_task, _sdk_task.SdkTask)
    assert isinstance(simple_sidecar_task, _sidecar_task.SdkSidecarTask)

    pod_spec = simple_sidecar_task.custom['podSpec']
    assert pod_spec['restartPolicy'] == 'OnFailure'
    assert len(pod_spec['containers']) == 2
    primary_container = pod_spec['containers'][0]
    assert primary_container['name'] == 'a container'
    assert primary_container['args'] == [
        'pyflyte-execute', '--task-module',
        'tests.flytekit.unit.sdk.tasks.test_sidecar_tasks', '--task-name',
Ejemplo n.º 21
0
from datetime import timedelta

from flytekit.models import interface as _interface
from flytekit.models import literals as _literals
from flytekit.models import types as _types
from flytekit.models.core import condition as _condition
from flytekit.models.core import identifier as _identifier
from flytekit.models.core import workflow as _workflow

# Module-level workflow Identifier fixture shared by tests in this module.
_generic_id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW,
                                     "project", "domain", "name", "version")


def test_node_metadata():
    """NodeMetadata retains timeout and retry settings across an IDL round trip."""
    original = _workflow.NodeMetadata(
        name="node1", timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0))
    round_tripped = _workflow.NodeMetadata.from_flyte_idl(original.to_flyte_idl())

    assert round_tripped == original
    for obj in (original, round_tripped):
        assert obj.timeout.seconds == 10
        assert obj.retries.retries == 0


def test_alias():
    """Alias keeps its var/alias fields and round-trips through IDL."""
    original = _workflow.Alias(var="myvar", alias="myalias")
    assert original.var == "myvar"
    assert original.alias == "myalias"
    assert _workflow.Alias.from_flyte_idl(original.to_flyte_idl()) == original
Ejemplo n.º 22
0
from itertools import product

import pytest
from google.protobuf import text_format

from flytekit.models import array_job as _array_job
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import literals as _literals
from flytekit.models import task as _task
from flytekit.models.core import identifier as _identifier
from flytekit.models.core import workflow as _workflow
from tests.flytekit.common import parameterizers

LIST_OF_DYNAMIC_TASKS = [
    _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "p", "d", "n",
                               "v"),
        "python",
        task_metadata,
        interfaces,
        _array_job.ArrayJob(2, 2, 2).to_dict(),
        container=_task.Container(
            "my_image",
            ["this", "is", "a", "cmd"],
            ["this", "is", "an", "arg"],
            resources,
            {"a": "b"},
            {"d": "e"},
        ),
    ) for task_metadata, interfaces, resources in product(
        parameterizers.LIST_OF_TASK_METADATA,
        parameterizers.LIST_OF_INTERFACES,
Ejemplo n.º 23
0
def test_workflow():
    """Build a six-node workflow from two tasks, then verify the node input
    bindings both on the constructed object and after a WorkflowTemplate
    IDL round trip.
    """
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                         "project", "domain", "my_task",
                                         "version")

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                              "project", "domain",
                                              "my_list_task", "version")

    input_list = [
        promise.Input("input_1", primitives.Integer),
        promise.Input("input_2",
                      primitives.Integer,
                      default=5,
                      help="Not required."),
    ]

    # n1/n2 bind workflow inputs, n3 a constant, n4 an upstream output,
    # n5 a mixed collection, and n6 depends on n5 (plus an explicit n1 >> n6).
    n1 = my_task(a=input_list[0]).assign_id_and_return("n1")
    n2 = my_task(a=input_list[1]).assign_id_and_return("n2")
    n3 = my_task(a=100).assign_id_and_return("n3")
    n4 = my_task(a=n1.outputs.b).assign_id_and_return("n4")
    n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100
                         ]).assign_id_and_return("n5")
    n6 = my_list_task(a=n5.outputs.b)
    n1 >> n6

    nodes = [n1, n2, n3, n4, n5, n6]

    w = _local_workflow.SdkRunnableWorkflow.construct_from_class_definition(
        inputs=input_list,
        outputs=[
            _local_workflow.Output("a",
                                   n1.outputs.b,
                                   sdk_type=primitives.Integer)
        ],
        nodes=nodes,
    )

    # Checks on the directly-constructed workflow object.
    assert w.interface.inputs[
        "input_1"].type == primitives.Integer.to_flyte_literal_type()
    assert w.interface.inputs[
        "input_2"].type == primitives.Integer.to_flyte_literal_type()
    assert w.nodes[0].inputs[0].var == "a"
    assert w.nodes[0].inputs[
        0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[0].inputs[0].binding.promise.var == "input_1"
    assert w.nodes[1].inputs[
        0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[1].inputs[0].binding.promise.var == "input_2"
    assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
    assert w.nodes[3].inputs[0].var == "a"
    assert w.nodes[3].inputs[0].binding.promise.node_id == n1.id

    # Test conversion to flyte_idl and back
    w._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "fake",
                                   "faker", "fakest", "fakerest")
    w = _workflow_models.WorkflowTemplate.from_flyte_idl(w.to_flyte_idl())
    # The same binding checks must hold after the round trip.
    assert w.interface.inputs[
        "input_1"].type == primitives.Integer.to_flyte_literal_type()
    assert w.interface.inputs[
        "input_2"].type == primitives.Integer.to_flyte_literal_type()
    assert w.nodes[0].inputs[0].var == "a"
    assert w.nodes[0].inputs[
        0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[0].inputs[0].binding.promise.var == "input_1"
    assert w.nodes[1].inputs[
        0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[1].inputs[0].binding.promise.var == "input_2"
    assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
    assert w.nodes[3].inputs[0].var == "a"
    assert w.nodes[3].inputs[0].binding.promise.node_id == n1.id
    # n5's input is a collection binding mixing globals, an upstream output,
    # and a literal.
    assert w.nodes[4].inputs[0].var == "a"
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        0].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        0].promise.var == "input_1"
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        1].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        1].promise.var == "input_2"
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        2].promise.node_id == n3.id
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        2].promise.var == "b"
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        3].scalar.primitive.integer == 100
    assert w.nodes[5].inputs[0].var == "a"
    assert w.nodes[5].inputs[0].binding.promise.node_id == n5.id
    assert w.nodes[5].inputs[0].binding.promise.var == "b"

    # The single workflow output is bound to n1's output "b".
    assert len(w.outputs) == 1
    assert w.outputs[0].var == "a"
    assert w.outputs[0].binding.promise.var == "b"
    assert w.outputs[0].binding.promise.node_id == "n1"
Ejemplo n.º 24
0
 def __post_init__(self):
     # Derive and cache the full Identifier model from this object's
     # resource_type/project/domain/name/version fields right after init.
     self._id = _identifier_model.Identifier(self.resource_type,
                                             self.project, self.domain,
                                             self.name, self.version)
Ejemplo n.º 25
0
def test_builtin_algorithm_training_job_task():
    """Construct a builtin-algorithm SageMaker training job task and verify
    its type, interface (inputs/outputs), metadata defaults, and that its
    custom config round-trips through the protobuf representation.
    """
    task = SdkBuiltinAlgorithmTrainingJobTask(
        training_job_resource_config=TrainingJobResourceConfig(
            instance_type="ml.m4.xlarge",
            instance_count=1,
            volume_size_in_gb=25,
        ),
        algorithm_specification=AlgorithmSpecification(
            input_mode=InputMode.FILE,
            input_content_type=InputContentType.TEXT_CSV,
            algorithm_name=AlgorithmName.XGBOOST,
            algorithm_version="0.72",
        ),
    )

    # Registration normally assigns the id; set it manually for the test.
    task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                      "my_project", "my_domain", "my_name",
                                      "my_version")

    assert isinstance(task, SdkBuiltinAlgorithmTrainingJobTask)
    assert isinstance(task, _sdk_task.SdkTask)

    # Both train and validation channels are typed as multipart CSV blobs.
    multipart_csv = _idl_types.LiteralType(blob=_core_types.BlobType(
        format="csv",
        dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
    ))
    task_inputs = task.interface.inputs
    assert task_inputs["train"].description == ""
    assert task_inputs["train"].type == multipart_csv
    assert (task_inputs["train"].type ==
            _sdk_types.Types.MultiPartCSV.to_flyte_literal_type())
    assert task_inputs["validation"].description == ""
    assert (task_inputs["validation"].type ==
            _sdk_types.Types.MultiPartCSV.to_flyte_literal_type())
    assert task_inputs["static_hyperparameters"].description == ""
    assert (task_inputs["static_hyperparameters"].type ==
            _sdk_types.Types.Generic.to_flyte_literal_type())

    assert task.interface.outputs["model"].description == ""
    assert (task.interface.outputs["model"].type ==
            _sdk_types.Types.Blob.to_flyte_literal_type())

    # Metadata should carry the defaults for an unconfigured task.
    assert task.type == _common_constants.SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK
    assert task.metadata.timeout == _datetime.timedelta(seconds=0)
    assert task.metadata.deprecated_error_message == ""
    assert task.metadata.discoverable is False
    assert task.metadata.discovery_version == ""
    assert task.metadata.retries.retries == 0

    # No metric definitions were supplied, so none should be serialized.
    assert "metricDefinitions" not in task.custom[
        "algorithmSpecification"].keys()

    # Fails the test if the serialized config cannot be parsed back.
    ParseDict(
        task.custom["trainingJobResourceConfig"],
        _pb2_TrainingJobResourceConfig(),
    )
Ejemplo n.º 26
0
def test_workflow_closure():
    """Assemble a minimal WorkflowClosure (one task, one node) and check
    it survives a round-trip through its flyte IDL representation.
    """
    integer = _types.LiteralType(_types.SimpleType.INTEGER)
    iface = _interface.TypedInterface(
        {'a': _interface.Variable(integer, "description1")},
        {
            'b': _interface.Variable(integer, "description2"),
            'c': _interface.Variable(integer, "description3"),
        },
    )

    # One literal scalar binding (node input) and two promise bindings
    # (workflow outputs referencing the node's outputs).
    scalar_binding = _literals.Binding(
        'a',
        _literals.BindingData(scalar=_literals.Scalar(
            primitive=_literals.Primitive(integer=5))),
    )
    promise_b = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'b')),
    )
    promise_c = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'c')),
    )

    node_meta = _workflow.NodeMetadata(
        name='node1',
        timeout=timedelta(seconds=10),
        retries=_literals.RetryStrategy(0),
    )

    task_meta = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(
            _task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
        timedelta(days=1),
        _literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!",
    )

    cpu = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu], limits=[cpu])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project",
                               "domain", "name", "version"),
        "python",
        task_meta,
        iface,
        {'a': 1, 'b': {'c': 2, 'd': 3}},
        container=_task.Container(
            "my_image",
            ["this", "is", "a", "cmd"],
            ["this", "is", "an", "arg"],
            resources,
            {},
            {},
        ),
    )

    node = _workflow.Node(
        id='my_node',
        metadata=node_meta,
        inputs=[scalar_binding],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(task.id),
    )

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW,
                                  "project", "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=iface,
        nodes=[node],
        outputs=[promise_b, promise_c],
    )

    closure = _workflow_closure.WorkflowClosure(workflow=template,
                                                tasks=[task])
    assert len(closure.tasks) == 1

    # Serializing to flyte IDL and back must preserve equality.
    rebuilt = _workflow_closure.WorkflowClosure.from_flyte_idl(
        closure.to_flyte_idl())
    assert closure == rebuilt
Ejemplo n.º 27
0
            MetricDefinition(name="Validation error", regex="validation:error")
        ],
    ),
)

# Wrap the pre-built XGBoost training job task in a SageMaker hyperparameter
# tuning job: up to 10 training jobs total, at most 5 in parallel, with
# results cacheable under cache version "1" and 2 retries per attempt.
simple_xgboost_hpo_job_task = hpo_job_task.SdkSimpleHyperparameterTuningJobTask(
    training_job=builtin_algorithm_training_job_task2,
    max_number_of_training_jobs=10,
    max_parallel_training_jobs=5,
    cache_version="1",
    retries=2,
    cacheable=True,
)

# Registration normally assigns the id; set it manually so tests below can
# exercise id-dependent behavior without registering the task.
simple_xgboost_hpo_job_task._id = _identifier.Identifier(
    _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name",
    "my_version")


def test_simple_hpo_job_task():
    assert isinstance(simple_xgboost_hpo_job_task,
                      SdkSimpleHyperparameterTuningJobTask)
    assert isinstance(simple_xgboost_hpo_job_task, _sdk_task.SdkTask)
    # Checking if the input of the underlying SdkTrainingJobTask has been embedded
    assert simple_xgboost_hpo_job_task.interface.inputs[
        "train"].description == ""
    assert (simple_xgboost_hpo_job_task.interface.inputs["train"].type ==
            _sdk_types.Types.MultiPartCSV.to_flyte_literal_type())
    assert simple_xgboost_hpo_job_task.interface.inputs[
        "train"].type == _idl_types.LiteralType(blob=_core_types.BlobType(
            format="csv",
Ejemplo n.º 28
0
def test_workflow_decorator():
    """Build a workflow from a metaclass-style class body and verify the
    serialized structure — input types, promise/scalar bindings on each
    node, and the single output — both before and after a round-trip
    through the flyte IDL representation.
    """
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task
    def my_task(wf_params, a, b):
        b.set(a + 1)

    # NOTE(review): "propject" looks like a typo for "project"; harmless
    # here since it is only an arbitrary fixture value.
    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                         "propject", "domain", "my_task",
                                         "version")

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                              "propject", "domain",
                                              "my_list_task", "version")

    # The attribute names (n1..n6) become the serialized node ids, and the
    # statement order determines the order of w.nodes asserted below — do
    # not reorder this class body.
    class my_workflow(object):
        input_1 = promise.Input("input_1", primitives.Integer)
        input_2 = promise.Input("input_2",
                                primitives.Integer,
                                default=5,
                                help="Not required.")
        n1 = my_task(a=input_1)
        n2 = my_task(a=input_2)
        n3 = my_task(a=100)
        n4 = my_task(a=n1.outputs.b)
        n5 = my_list_task(a=[input_1, input_2, n3.outputs.b, 100])
        n6 = my_list_task(a=n5.outputs.b)
        # Explicit ordering dependency: n6 must run after n1.
        n1 >> n6
        a = _local_workflow.Output("a",
                                   n1.outputs.b,
                                   sdk_type=primitives.Integer)

    w = _local_workflow.build_sdk_workflow_from_metaclass(
        my_workflow,
        on_failure=_workflow_models.WorkflowMetadata.OnFailurePolicy.
        FAIL_AFTER_EXECUTABLE_NODES_COMPLETE,
    )

    assert w.should_create_default_launch_plan is True

    # Workflow-level inputs are typed from the promise.Input declarations.
    assert w.interface.inputs[
        "input_1"].type == primitives.Integer.to_flyte_literal_type()
    assert w.interface.inputs[
        "input_2"].type == primitives.Integer.to_flyte_literal_type()
    # n1/n2 bind their input "a" to the global workflow inputs.
    assert w.nodes[0].inputs[0].var == "a"
    assert w.nodes[0].inputs[
        0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[0].inputs[0].binding.promise.var == "input_1"
    assert w.nodes[1].inputs[
        0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[1].inputs[0].binding.promise.var == "input_2"
    # n3 was given a literal, n4 a promise from n1's output.
    assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
    assert w.nodes[3].inputs[0].var == "a"
    assert w.nodes[3].inputs[0].binding.promise.node_id == "n1"

    # Test conversion to flyte_idl and back; the structural assertions
    # below mirror the ones above and extend to the list-task bindings.
    w.id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "fake",
                                  "faker", "fakest", "fakerest")
    w = _workflow_models.WorkflowTemplate.from_flyte_idl(w.to_flyte_idl())
    assert w.interface.inputs[
        "input_1"].type == primitives.Integer.to_flyte_literal_type()
    assert w.interface.inputs[
        "input_2"].type == primitives.Integer.to_flyte_literal_type()
    assert w.nodes[0].inputs[0].var == "a"
    assert w.nodes[0].inputs[
        0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[0].inputs[0].binding.promise.var == "input_1"
    assert w.nodes[1].inputs[
        0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[1].inputs[0].binding.promise.var == "input_2"
    assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
    assert w.nodes[3].inputs[0].var == "a"
    assert w.nodes[3].inputs[0].binding.promise.node_id == "n1"
    # n5 received a mixed list: two global inputs, one node output, one
    # literal — serialized as a binding collection in that order.
    assert w.nodes[4].inputs[0].var == "a"
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        0].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        0].promise.var == "input_1"
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        1].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        1].promise.var == "input_2"
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        2].promise.node_id == "n3"
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        2].promise.var == "b"
    assert w.nodes[4].inputs[0].binding.collection.bindings[
        3].scalar.primitive.integer == 100
    assert w.nodes[5].inputs[0].var == "a"
    assert w.nodes[5].inputs[0].binding.promise.node_id == "n5"
    assert w.nodes[5].inputs[0].binding.promise.var == "b"

    # Single workflow output "a", bound to n1's output "b".
    assert len(w.outputs) == 1
    assert w.outputs[0].var == "a"
    assert w.outputs[0].binding.promise.var == "b"
    assert w.outputs[0].binding.promise.node_id == "n1"
    assert (w.metadata.on_failure == _workflow_models.WorkflowMetadata.
            OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE)
Ejemplo n.º 29
0
def test_execution_spec(literal_value_pair):
    """ExecutionSpec fields are preserved on construction and across a
    round-trip through the flyte IDL representation, both with explicit
    notifications and with all notifications disabled.
    """
    literal_value, _ = literal_value_pair

    def _launch_plan_id():
        return _identifier.Identifier(_identifier.ResourceType.LAUNCH_PLAN,
                                      "project", "domain", "name", "version")

    def _metadata():
        return _execution.ExecutionMetadata(
            _execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1)

    def _check_common(spec):
        # Fields shared by every spec constructed in this test.
        assert spec.launch_plan.resource_type == _identifier.ResourceType.LAUNCH_PLAN
        assert spec.launch_plan.project == "project"
        assert spec.launch_plan.domain == "domain"
        assert spec.launch_plan.name == "name"
        assert spec.launch_plan.version == "version"
        assert spec.metadata.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL
        assert spec.metadata.principal == "tester"
        assert spec.metadata.nesting == 1

    # Spec carrying a pager-duty notification for aborted executions.
    with_notifications = _execution.ExecutionSpec(
        _launch_plan_id(),
        _metadata(),
        notifications=_execution.NotificationList([
            _common_models.Notification(
                [_core_exec.WorkflowExecutionPhase.ABORTED],
                pager_duty=_common_models.PagerDutyNotification(
                    recipients_email=["a", "b", "c"]),
            )
        ]),
    )
    assert with_notifications == _execution.ExecutionSpec.from_flyte_idl(
        with_notifications.to_flyte_idl())
    for spec in (with_notifications,
                 _execution.ExecutionSpec.from_flyte_idl(
                     with_notifications.to_flyte_idl())):
        _check_common(spec)
        notification = spec.notifications.notifications[0]
        assert notification.phases == [
            _core_exec.WorkflowExecutionPhase.ABORTED
        ]
        assert notification.pager_duty.recipients_email == ["a", "b", "c"]
        assert spec.disable_all is None

    # Spec with every notification disabled instead of an explicit list.
    disabled = _execution.ExecutionSpec(
        _launch_plan_id(),
        _metadata(),
        disable_all=True,
    )
    assert disabled == _execution.ExecutionSpec.from_flyte_idl(
        disabled.to_flyte_idl())
    for spec in (disabled,
                 _execution.ExecutionSpec.from_flyte_idl(
                     disabled.to_flyte_idl())):
        _check_common(spec)
        assert spec.notifications is None
        assert spec.disable_all is True
Ejemplo n.º 30
0
def test_workflow_node():
    """Build a workflow imperatively (explicit node/input/output lists),
    then invoke it as a sub-workflow node and verify input validation,
    default handling, and promised outputs.
    """
    @inputs(a=primitives.Integer)
    @outputs(b=primitives.Integer)
    @python_task()
    def my_task(wf_params, a, b):
        b.set(a + 1)

    # Registration normally assigns the id; set it manually for the test.
    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                         "project", "domain", "my_task",
                                         "version")

    @inputs(a=[primitives.Integer])
    @outputs(b=[primitives.Integer])
    @python_task
    def my_list_task(wf_params, a, b):
        b.set([v + 1 for v in a])

    my_list_task._id = _identifier.Identifier(_identifier.ResourceType.TASK,
                                              "project", "domain",
                                              "my_list_task", "version")

    # One required input and one with a default of 5.
    input_list = [
        promise.Input("required", primitives.Integer),
        promise.Input("not_required",
                      primitives.Integer,
                      default=5,
                      help="Not required."),
    ]

    n1 = my_task(a=input_list[0]).assign_id_and_return("n1")
    n2 = my_task(a=input_list[1]).assign_id_and_return("n2")
    n3 = my_task(a=100).assign_id_and_return("n3")
    n4 = my_task(a=n1.outputs.b).assign_id_and_return("n4")
    n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100
                         ]).assign_id_and_return("n5")
    # n6 deliberately keeps its auto-assigned id (no assign_id_and_return).
    n6 = my_list_task(a=n5.outputs.b)

    nodes = [n1, n2, n3, n4, n5, n6]

    wf_out = [
        _local_workflow.Output(
            "nested_out",
            [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]],
            sdk_type=[[primitives.Integer]],
        ),
        _local_workflow.Output("scalar_out",
                               n1.outputs.b,
                               sdk_type=primitives.Integer),
    ]

    w = _local_workflow.SdkRunnableWorkflow.construct_from_class_definition(
        inputs=input_list, outputs=wf_out, nodes=nodes)

    # Test that required input isn't set
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        w()

    # Test that positional args are rejected
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        w(1, 2)

    # Test that type checking works
    with _pytest.raises(_user_exceptions.FlyteTypeException):
        w(required="abc", not_required=1)

    # Test that bad arg name is detected
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        w(required=1, bad_arg=1)

    # Test default input is accounted for
    n = w(required=10)
    assert n.inputs[0].var == "not_required"
    assert n.inputs[0].binding.scalar.primitive.integer == 5
    assert n.inputs[1].var == "required"
    assert n.inputs[1].binding.scalar.primitive.integer == 10

    # Test default input is overridden
    n = w(required=10, not_required=50)
    assert n.inputs[0].var == "not_required"
    assert n.inputs[0].binding.scalar.primitive.integer == 50
    assert n.inputs[1].var == "required"
    assert n.inputs[1].binding.scalar.primitive.integer == 10

    # Test that workflow is saved in the node
    w.id = "fake"
    assert n.workflow_node.sub_workflow_ref == "fake"
    w.id = None

    # Test that outputs are promised; the "*" is stripped when the id is
    # dns'ified, so the promises report node_id "node-id".
    n.assign_id_and_return("node-id*")  # dns'ified
    assert n.outputs["scalar_out"].sdk_type.to_flyte_literal_type(
    ) == primitives.Integer.to_flyte_literal_type()
    assert n.outputs["scalar_out"].var == "scalar_out"
    assert n.outputs["scalar_out"].node_id == "node-id"

    # nested_out is a list-of-lists of integers assembled from several nodes.
    assert (n.outputs["nested_out"].sdk_type.to_flyte_literal_type() ==
            containers.List(containers.List(
                primitives.Integer)).to_flyte_literal_type())
    assert n.outputs["nested_out"].var == "nested_out"
    assert n.outputs["nested_out"].node_id == "node-id"