Exemplo n.º 1
0
    def __init__(self,
                 id,
                 upstream_nodes,
                 bindings,
                 metadata,
                 sdk_task=None,
                 sdk_workflow=None,
                 sdk_launch_plan=None,
                 sdk_branch=None):
        """
        Build an SDK node wrapping exactly one executable entity (task, workflow, launch plan or branch).

        :param Text id: A workflow-level unique identifier that identifies this node in the workflow. "inputs" and
            "outputs" are reserved node ids that cannot be used by other nodes. The value is passed through
            ``_dnsify``; a falsy id is stored as None.
        :param list[SdkNode] upstream_nodes: Specifies execution dependencies for this node ensuring it will
            only get scheduled to run after all its upstream nodes have completed. This node will have
            an implicit dependency on any node that appears in inputs field.
        :param list[flytekit.models.literals.Binding] bindings: Specifies how to bind the underlying
            interface's inputs.  All required inputs specified in the underlying interface must be fulfilled.
        :param flytekit.models.core.workflow.NodeMetadata metadata: Extra metadata about the node.
        :param flytekit.common.tasks.task.SdkTask sdk_task: The task to execute in this node.
        :param flytekit.common.workflow.SdkWorkflow sdk_workflow: The workflow to execute in this node.
        :param flytekit.common.launch_plan.SdkLaunchPlan sdk_launch_plan: The launch plan to execute in this
            node.
        :param sdk_branch: A branch entity to execute in this node; its ``target`` attribute is used as the
            underlying branch node. TODO(review): document the concrete branch type.
        :raises FlyteAssertion: if not exactly one of sdk_task, sdk_workflow, sdk_launch_plan and sdk_branch
            is provided (i.e. is not None).

        Whichever entity is provided must expose ``interface.outputs``; it is used to build the node's
        output mapper.
        """
        non_none_entities = [
            entity for entity in
            [sdk_workflow, sdk_branch, sdk_launch_plan, sdk_task]
            if entity is not None
        ]
        if len(non_none_entities) != 1:
            raise _user_exceptions.FlyteAssertion(
                "An SDK node must have one underlying entity specified at once.  Received the following "
                "entities: {}".format(non_none_entities))
        # Exactly one entity passed the ``is not None`` check above. Reuse it directly instead of an
        # ``a or b or c`` chain, which would skip a present-but-falsy entity and could yield None.
        executable_sdk_object = non_none_entities[0]

        # Workflows and launch plans are both wrapped in a workflow node; at most one of them is set here.
        workflow_node = None
        if sdk_workflow is not None:
            workflow_node = _component_nodes.SdkWorkflowNode(
                sdk_workflow=sdk_workflow)
        elif sdk_launch_plan is not None:
            workflow_node = _component_nodes.SdkWorkflowNode(
                sdk_launch_plan=sdk_launch_plan)

        super(SdkNode, self).__init__(
            id=_dnsify(id) if id else None,
            metadata=metadata,
            inputs=bindings,
            upstream_node_ids=[n.id for n in upstream_nodes],
            output_aliases=[],  # TODO: Are aliases a thing in SDK nodes
            # Use ``is not None`` (matching the validation above) rather than truthiness.
            task_node=_component_nodes.SdkTaskNode(sdk_task)
            if sdk_task is not None else None,
            workflow_node=workflow_node,
            branch_node=sdk_branch.target if sdk_branch is not None else None)
        self._upstream = upstream_nodes
        self._executable_sdk_object = executable_sdk_object
        self._outputs = OutputParameterMapper(
            self._executable_sdk_object.interface.outputs, self)
Exemplo n.º 2
0
def test_sdk_launch_plan_node():
    """SdkWorkflowNode built from a launch plan exposes that plan's identifier via ``launchplan_ref``.

    Also checks that the reference follows later changes to the launch plan's id ("floating ID"),
    and that supplying both a workflow and a launch plan raises FlyteSystemException.
    """
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    @_workflow.workflow_class
    class test_workflow(object):
        a = _workflow.Input(_types.Types.Integer)
        test = testy_test(a=1)
        b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer)

    def _assert_ref(ref, project, domain, name, version):
        # Check every identifier field of a launch plan reference.
        assert ref.project == project
        assert ref.domain == domain
        assert ref.name == name
        assert ref.version == version

    lp = test_workflow.create_launch_plan()

    # The identifier belongs to a launch plan, so tag it with the LAUNCH_PLAN resource type.
    lp._id = _identifier.Identifier(_identifier.ResourceType.LAUNCH_PLAN, "project",
                                    "domain", "name", "version")
    n = _component_nodes.SdkWorkflowNode(sdk_launch_plan=lp)
    _assert_ref(n.launchplan_ref, "project", "domain", "name", "version")

    # Test floating ID: replacing the launch plan's id after the node was built must be reflected.
    lp._id = _identifier.Identifier(
        _identifier.ResourceType.LAUNCH_PLAN,
        "new_project",
        "new_domain",
        "new_name",
        "new_version",
    )
    _assert_ref(n.launchplan_ref, "new_project", "new_domain", "new_name", "new_version")

    # If you specify both, you should get an exception
    with _pytest.raises(_system_exceptions.FlyteSystemException):
        _component_nodes.SdkWorkflowNode(sdk_workflow=test_workflow,
                                         sdk_launch_plan=lp)