def __init__(
    self,
    id,
    upstream_nodes,
    bindings,
    metadata,
    sdk_task=None,
    sdk_workflow=None,
    sdk_launch_plan=None,
    sdk_branch=None,
):
    """
    Build a node wrapping exactly one executable SDK entity (task, workflow, launch plan or branch).

    :param Text id: A workflow-level unique identifier that identifies this node in the workflow. "inputs" and
        "outputs" are reserved node ids that cannot be used by other nodes. A falsy id is passed to the parent
        as None; any other id is run through ``_dnsify`` first.
    :param list[SdkNode] upstream_nodes: Specifies execution dependencies for this node ensuring it will only get
        scheduled to run after all its upstream nodes have completed. This node will have an implicit dependency
        on any node that appears in inputs field.
    :param list[flytekit.models.literals.Binding] bindings: Specifies how to bind the underlying interface's
        inputs. All required inputs specified in the underlying interface must be fulfilled.
    :param flytekit.models.core.workflow.NodeMetadata metadata: Extra metadata about the node.
    :param flytekit.common.tasks.task.SdkTask sdk_task: The task to execute in this node.
    :param flytekit.common.workflow.SdkWorkflow sdk_workflow: The workflow to execute in this node.
    :param flytekit.common.launch_plan.SdkLaunchPlan sdk_launch_plan: The launch plan to execute in this node.
    :param sdk_branch: The branch to execute in this node. Its ``target`` attribute is forwarded as the node's
        ``branch_node`` and, like the other entities, it must expose ``interface.outputs``. The concrete type is
        not visible from this constructor -- TODO confirm and document.
    :raises _user_exceptions.FlyteAssertion: If zero or more than one of ``sdk_task``, ``sdk_workflow``,
        ``sdk_launch_plan`` and ``sdk_branch`` is provided.
    """
    non_none_entities = [
        entity for entity in [sdk_workflow, sdk_branch, sdk_launch_plan, sdk_task] if entity is not None
    ]
    if len(non_none_entities) != 1:
        raise _user_exceptions.FlyteAssertion(
            "An SDK node must have one underlying entity specified at once.  Received the following "
            "entities: {}".format(non_none_entities)
        )

    # Validation above is identity based (``is not None``); reuse its result instead of re-deriving the
    # entity through truthiness, so an entity that happens to evaluate falsy is not silently discarded.
    executable_sdk_object = non_none_entities[0]

    # Workflows and launch plans are both wrapped in the same workflow-node component.
    workflow_node = None
    if sdk_workflow is not None:
        workflow_node = _component_nodes.SdkWorkflowNode(sdk_workflow=sdk_workflow)
    elif sdk_launch_plan is not None:
        workflow_node = _component_nodes.SdkWorkflowNode(sdk_launch_plan=sdk_launch_plan)

    super(SdkNode, self).__init__(
        id=_dnsify(id) if id else None,
        metadata=metadata,
        inputs=bindings,
        upstream_node_ids=[n.id for n in upstream_nodes],
        output_aliases=[],  # TODO: Are aliases a thing in SDK nodes
        task_node=_component_nodes.SdkTaskNode(sdk_task) if sdk_task is not None else None,
        workflow_node=workflow_node,
        branch_node=sdk_branch.target if sdk_branch is not None else None,
    )
    self._upstream = upstream_nodes
    self._executable_sdk_object = executable_sdk_object
    self._outputs = OutputParameterMapper(self._executable_sdk_object.interface.outputs, self)
def test_sdk_launch_plan_node():
    """
    Exercise ``SdkWorkflowNode`` built from a launch plan.

    Covers three things: the node's ``launchplan_ref`` mirrors the launch plan's identifier, the reference
    follows the identifier when it is reassigned after the node was created, and supplying both a workflow
    and a launch plan raises ``FlyteSystemException``.
    """

    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    @_workflow.workflow_class
    class test_workflow(object):
        a = _workflow.Input(_types.Types.Integer)
        test = testy_test(a=1)
        b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer)

    ref_fields = ("project", "domain", "name", "version")

    def assert_launchplan_ref(node, expected_values):
        # Re-read ``launchplan_ref`` for every field, exactly as the individual assertions would.
        for field, expected in zip(ref_fields, expected_values):
            assert getattr(node.launchplan_ref, field) == expected

    launch_plan = test_workflow.create_launch_plan()
    launch_plan._id = _identifier.Identifier(
        _identifier.ResourceType.TASK, "project", "domain", "name", "version"
    )
    node = _component_nodes.SdkWorkflowNode(sdk_launch_plan=launch_plan)
    assert_launchplan_ref(node, ("project", "domain", "name", "version"))

    # Test floating ID: swap the identifier on the launch plan after the node exists and make sure the
    # node's reference reports the new values.
    launch_plan._id = _identifier.Identifier(
        _identifier.ResourceType.TASK,
        "new_project",
        "new_domain",
        "new_name",
        "new_version",
    )
    assert_launchplan_ref(node, ("new_project", "new_domain", "new_name", "new_version"))

    # If you specify both, you should get an exception
    with _pytest.raises(_system_exceptions.FlyteSystemException):
        _component_nodes.SdkWorkflowNode(sdk_workflow=test_workflow, sdk_launch_plan=launch_plan)