コード例 #1
0
ファイル: base_task.py プロジェクト: fediazgon/flytekit
 def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
     """
     Build the NodeMetadata for a node that wraps this task.

     The node is named ``<module>.<task name>``; timeout, retry strategy and the interruptible flag are copied
     from this task's own metadata.
     """
     qualified_name = f"{self.__module__}.{self.name}"
     task_meta = self.metadata
     return _workflow_model.NodeMetadata(
         name=qualified_name,
         timeout=task_meta.timeout,
         retries=task_meta.retry_strategy,
         interruptible=task_meta.interruptible,
     )
コード例 #2
0
    def __call__(self, *args, **input_map):
        """
        Add this launch plan as a node in a workflow.

        Non-required launch plan parameters contribute their SDK defaults; any keyword input overrides them.

        :param list[T] args: Do not specify.  Kwargs only are supported for this function.
        :param dict[Text,T] input_map: Map of inputs.  Can be statically defined or OutputReference links.
        :raises FlyteAssertion: if any positional arguments are given.
        :rtype: flytekit.common.nodes.SdkNode
        """
        if args:
            raise _user_exceptions.FlyteAssertion(
                "When adding a launchplan as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args)))

        # Start from the optional parameters' defaults, then let explicit kwargs win.
        merged_inputs = {}
        for param_name, param in _six.iteritems(self.default_inputs.parameters):
            if not param.required:
                merged_inputs[param_name] = param.sdk_default
        merged_inputs.update(input_map)

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(merged_inputs)

        placeholder_metadata = _workflow_models.NodeMetadata(
            "", _datetime.timedelta(), _literal_models.RetryStrategy(0))
        # Bindings are ordered by variable name.
        ordered_bindings = sorted(bindings, key=lambda binding: binding.var)

        return _nodes.SdkNode(
            id=None,
            metadata=placeholder_metadata,
            bindings=ordered_bindings,
            upstream_nodes=upstream_nodes,
            sdk_launch_plan=self,
        )
コード例 #3
0
ファイル: workflow.py プロジェクト: jbrambleDC/flytekit
    def __call__(self, *args, **input_map):
        """
        Create an SdkNode that embeds this workflow inside another workflow.

        Only keyword inputs are accepted. Every user input that is not ``sdk_required`` falls back to its
        ``sdk_default`` unless ``input_map`` supplies a value for it.

        :raises FlyteAssertion: if any positional arguments are given.
        :rtype: flytekit.common.nodes.SdkNode
        """
        positional_count = len(args)
        if positional_count:
            raise _user_exceptions.FlyteAssertion(
                "When adding a workflow as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(positional_count)
            )

        compiled_inputs = {}
        for user_input in self.user_inputs:
            if user_input.sdk_required:
                continue
            compiled_inputs[user_input.name] = user_input.sdk_default
        # Explicit kwargs override the defaults gathered above.
        compiled_inputs.update(input_map)

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(compiled_inputs)

        placeholder_metadata = _workflow_models.NodeMetadata(
            "placeholder", _datetime.timedelta(), _literal_models.RetryStrategy(0)
        )
        return _nodes.SdkNode(
            id=None,
            metadata=placeholder_metadata,
            upstream_nodes=upstream_nodes,
            bindings=sorted(bindings, key=lambda binding: binding.var),
            sdk_workflow=self,
        )
コード例 #4
0
    def __call__(self, *args, **input_map):
        """
        Wrap this task in an SdkNode so it can be used as a node in a workflow.

        :param list[T] args: Do not specify.  Kwargs only are supported for this function.
        :param dict[str, T] input_map: Map of inputs.  Can be statically defined or OutputReference links.
        :raises FlyteAssertion: if any positional arguments are given.
        :rtype: flytekit.common.nodes.SdkNode
        """
        if args:
            raise _user_exceptions.FlyteAssertion(
                "When adding a task as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args)))

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(input_map)

        # NOTE: this method is not overridden at the SdkRunnableTask layer, so ``self`` is sometimes a locally
        # executable task and sometimes something that cannot run locally (e.g. a pure SdkTask fetched from Admin).
        # TODO: Remove DEADBEEF (presumably a placeholder node name -- confirm).
        task_meta = self.metadata
        node_metadata = _workflow_model.NodeMetadata(
            "DEADBEEF",
            task_meta.timeout,
            task_meta.retries,
            task_meta.interruptible,
        )
        return _nodes.SdkNode(
            id=None,
            metadata=node_metadata,
            bindings=sorted(bindings, key=lambda binding: binding.var),
            upstream_nodes=upstream_nodes,
            sdk_task=self,
        )
コード例 #5
0
def test_node_metadata():
    """NodeMetadata keeps its timeout and retry count, both directly and after a to/from IDL round trip."""
    original = _workflow.NodeMetadata(
        name='node1',
        timeout=timedelta(seconds=10),
        retries=_literals.RetryStrategy(0),
    )
    assert original.timeout.seconds == 10
    assert original.retries.retries == 0

    idl_message = original.to_flyte_idl()
    restored = _workflow.NodeMetadata.from_flyte_idl(idl_message)
    assert original == restored
    assert restored.timeout.seconds == 10
    assert restored.retries.retries == 0
コード例 #6
0
def get_sample_node_metadata(node_id):
    """
    Build a sample NodeMetadata named ``node_id`` with a 10 second timeout and ``RetryStrategy(0)``.

    :param Text node_id: name to give the node metadata
    :rtype: flytekit.models.core.workflow.NodeMetadata
    """
    ten_seconds = timedelta(seconds=10)
    no_retries = _literals.RetryStrategy(0)
    return _workflow_model.NodeMetadata(name=node_id, timeout=ten_seconds, retries=no_retries)
コード例 #7
0
 def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
     """
     Used when constructing the node that encapsulates this task as part of a broader workflow definition.

     The node is named via ``extract_obj_name(self.name)`` and inherits this task's timeout, retry strategy and
     interruptible setting.
     """
     node_name = extract_obj_name(self.name)
     task_meta = self.metadata
     return _workflow_model.NodeMetadata(
         name=node_name,
         timeout=task_meta.timeout,
         retries=task_meta.retry_strategy,
         interruptible=task_meta.interruptible,
     )
コード例 #8
0
    def end_branch(self) -> Union[Condition, Promise]:
        """
        This should be invoked after every branch has been visited.

        Behaviour depends on the current FlyteContext:

        * Local workflow execution: the branch is marked complete. If this is not the last case, the condition is
          returned so that further cases can be chained. On the last case the conditional section is exited and
          the selected case's output promise is returned, or its recorded error is raised as a ``ValueError``.
        * Compilation: if this is not the last case, the condition is returned. On the last case the conditional
          section is exited, a branch node is built via ``to_branch_node``, bound to every promise that is not yet
          ready, added to the compilation state, and the outputs computed from that node are returned.

        :raises AssertionError: if the selected case has neither an output promise nor an error, or if called
            outside both local workflow execution and compilation.
        :raises ValueError: carrying the selected case's error during local execution.
        """
        ctx = FlyteContext.current_context()
        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            """
            In case of Local workflow execution, we should first mark the branch as complete, then
            Then we first check for if this is the last case,
            In case this is the last case, we return the output from the selected case - A case should always
            be selected (see start_branch)
            If this is not the last case, we should return the condition so that further chaining can be done
            """
            # Let us mark the execution state as complete
            ctx.execution_state.branch_complete()
            if self._last_case:
                ctx.execution_state.exit_conditional_section()
                # The selected case must have produced either an output promise or an error.
                if self._selected_case.output_promise is None and self._selected_case.err is None:
                    raise AssertionError("Bad conditional statements, did not resolve in a promise")
                elif self._selected_case.output_promise is not None:
                    return self._selected_case.output_promise
                # Only remaining possibility: no output promise, but an error was recorded.
                raise ValueError(self._selected_case.err)
            # Not the last case: hand back the condition so more cases can be chained.
            return self._condition
        elif ctx.compilation_state:
            ########
            # COMPILATION MODE
            """
            In case this is not local workflow execution then, we should check if this is the last case.
            If so then return the promise, else return the condition
            """
            if self._last_case:
                ctx.compilation_state.exit_conditional_section()
                # branch_nodes = ctx.compilation_state.nodes
                node, promises = to_branch_node(self._name, self)
                # Verify branch_nodes == nodes in bn
                bindings: typing.List[Binding] = []
                upstream_nodes = set()
                for p in promises:
                    # Promises that are not ready become input bindings of the branch node, and the nodes that
                    # produce them become its upstream dependencies.
                    if not p.is_ready:
                        bindings.append(Binding(var=p.var, binding=BindingData(promise=p.ref)))
                        upstream_nodes.add(p.ref.node)

                # The node id is built from the compilation prefix and the count of nodes already added.
                n = Node(
                    id=f"{ctx.compilation_state.prefix}node-{len(ctx.compilation_state.nodes)}",
                    metadata=_core_wf.NodeMetadata(self._name, timeout=datetime.timedelta(), retries=RetryStrategy(0)),
                    bindings=sorted(bindings, key=lambda b: b.var),
                    upstream_nodes=list(upstream_nodes),  # type: ignore
                    flyte_entity=node,
                )
                ctx.compilation_state.add_node(n)
                return self._compute_outputs(n)
            # Not the last case: hand back the condition so more cases can be chained.
            return self._condition

        raise AssertionError("Branches can only be invoked within a workflow context!")
コード例 #9
0
def _create_hive_job_node(name, hive_job, metadata):
    """
    Wrap a Qubole Hive job in a standalone SdkNode with a freshly generated UUID as its id.

    The node has no upstream nodes and no bindings. Only the timeout is taken from ``metadata``; node-level
    retries are fixed at ``RetryStrategy(0)``.

    :param Text name: name recorded in the node metadata
    :param _qubole.QuboleHiveJob hive_job: Hive job spec
    :param flytekit.models.task.TaskMetadata metadata: This contains information needed at runtime to determine
        behavior such as whether or not outputs are discoverable, timeouts, and retries.
    :rtype: _nodes.SdkNode
    """
    node_id = _six.text_type(_uuid.uuid4())
    node_metadata = _workflow_model.NodeMetadata(name, metadata.timeout, _literal_models.RetryStrategy(0))
    hive_task = SdkHiveJob(hive_job, metadata)
    return _nodes.SdkNode(
        id=node_id,
        upstream_nodes=[],
        bindings=[],
        metadata=node_metadata,
        sdk_task=hive_task,
    )
コード例 #10
0
def test_sdk_node_from_lp():
    """An SdkNode wrapping a launch plan exposes the expected id, input binding, output and metadata."""
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    @_workflow.workflow_class
    class test_workflow(object):
        a = _workflow.Input(_types.Types.Integer)
        test = testy_test(a=a)
        b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer)

    lp = test_workflow.create_launch_plan()

    int_literal_type = _types.Types.Integer.to_flyte_literal_type()
    input_binding = _literals.Binding("a", _interface.BindingData.from_python_std(int_literal_type, 3))
    node_metadata = _core_workflow_models.NodeMetadata(
        "abc", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3))
    node = _nodes.SdkNode("n1", [], [input_binding], node_metadata, sdk_launch_plan=lp)

    # Identity and the single literal input.
    assert node.id == "n1"
    assert len(node.inputs) == 1
    assert node.inputs[0].var == "a"
    assert node.inputs[0].binding.scalar.primitive.integer == 3

    # The launch plan's single output "b" points back at this node.
    assert len(node.outputs) == 1
    assert "b" in node.outputs
    assert node.outputs["b"].node_id == "n1"
    assert node.outputs["b"].var == "b"
    assert node.outputs["b"].sdk_node == node
    assert node.outputs["b"].sdk_type == _types.Types.Integer

    # Metadata and (absent) dependencies.
    assert node.metadata.name == "abc"
    assert node.metadata.retries.retries == 3
    assert len(node.upstream_nodes) == 0
    assert len(node.upstream_node_ids) == 0
    assert len(node.output_aliases) == 0
コード例 #11
0
ファイル: workflow.py プロジェクト: vglocus/flytekit
    def __call__(self, *args, **input_map):
        """
        Embed this workflow as a node of an enclosing workflow.

        :raises FlyteAssertion: if any positional arguments are given; inputs must be keyword arguments.
        :rtype: flytekit.common.nodes.SdkNode
        """
        if args:
            raise _user_exceptions.FlyteAssertion(
                "When adding a workflow as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args)))

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(input_map)

        placeholder_metadata = _workflow_models.NodeMetadata(
            "placeholder", _datetime.timedelta(), _literal_models.RetryStrategy(0))
        return _nodes.SdkNode(
            id=None,
            metadata=placeholder_metadata,
            upstream_nodes=upstream_nodes,
            bindings=sorted(bindings, key=lambda binding: binding.var),
            sdk_workflow=self,
        )
コード例 #12
0
def test_future_task_document(task):
    """A DynamicJobSpec holding one task node serializes identically after an IDL round trip."""
    retry_strategy = _literals.RetryStrategy(0)
    node_metadata = _workflow.NodeMetadata('node-name', _timedelta(minutes=10), retry_strategy)
    task_node = _workflow.Node(
        id="id",
        metadata=node_metadata,
        inputs=[],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=_workflow.TaskNode(task.id),
    )
    # Serializing the node on its own must not raise.
    task_node.to_flyte_idl()

    spec = _dynamic_job.DynamicJobSpec(
        tasks=[task],
        nodes=[task_node],
        min_successes=1,
        outputs=[_literals.Binding("var", _literals.BindingData())],
        subworkflows=[],
    )
    expected_text = text_format.MessageToString(spec.to_flyte_idl())
    round_tripped = _dynamic_job.DynamicJobSpec.from_flyte_idl(spec.to_flyte_idl())
    assert expected_text == text_format.MessageToString(round_tripped.to_flyte_idl())
コード例 #13
0
    def end_branch(
        self
    ) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidPromise]]:
        """
        This should be invoked after every branch has been visited.

        If this is not the last case, the condition is returned so that further cases can be chained. On the last
        case the conditional's branch context is popped, a branch node is built via ``to_branch_node``, that node
        is added to the current compilation state, and the outputs computed from it are returned.
        """
        if not self._last_case:
            return self._condition

        # We have completed the conditional section, lets pop off the branch context
        FlyteContextManager.pop_context()
        ctx = FlyteContextManager.current_context()
        # Question: This is commented out because we don't need it? Nodes created in the conditional
        #   compilation state are captured in the to_case_block? Always?
        #   Is this still true of nested conditionals? Is that why propeller compiler is complaining?
        # branch_nodes = ctx.compilation_state.nodes
        node, promises = to_branch_node(self._name, self)
        # Verify branch_nodes == nodes in bn
        bindings: typing.List[Binding] = []
        upstream_nodes = set()
        for p in promises:
            # Promises that are not ready become input bindings of the branch node, and the nodes that produce
            # them become its upstream dependencies.
            if not p.is_ready:
                bindings.append(Binding(var=p.var, binding=BindingData(promise=p.ref)))
                upstream_nodes.add(p.ref.node)

        n = Node(
            id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",  # type: ignore
            metadata=_core_wf.NodeMetadata(self._name, timeout=datetime.timedelta(), retries=RetryStrategy(0)),
            bindings=sorted(bindings, key=lambda b: b.var),
            upstream_nodes=list(upstream_nodes),  # type: ignore
            flyte_entity=node,
        )
        # The ``type: ignore`` must sit on the line that dereferences the Optional ``compilation_state``; mypy does
        # not honour an ignore placed on a continuation line of a split call.
        FlyteContextManager.current_context().compilation_state.add_node(n)  # type: ignore
        return self._compute_outputs(n)
コード例 #14
0
ファイル: task.py プロジェクト: sauravsrijan/flytekit
    def __call__(self, *args, **input_map):
        """
        Return an SdkNode that runs this task as a node of a workflow.

        :param list[T] args: Do not specify.  Kwargs only are supported for this function.
        :param dict[str, T] input_map: Map of inputs.  Can be statically defined or OutputReference links.
        :raises FlyteAssertion: if any positional arguments are given.
        :rtype: flytekit.common.nodes.SdkNode
        """
        if args:
            raise _user_exceptions.FlyteAssertion(
                "When adding a task as a node in a workflow, all inputs must be specified with kwargs only.  We "
                "detected {} positional args.".format(len(args)))

        bindings, upstream_nodes = self.interface.create_bindings_for_inputs(input_map)

        # TODO: Remove DEADBEEF
        task_meta = self.metadata
        node_metadata = _workflow_model.NodeMetadata("DEADBEEF", task_meta.timeout, task_meta.retries)
        return _nodes.SdkNode(
            id=None,
            metadata=node_metadata,
            bindings=sorted(bindings, key=lambda binding: binding.var),
            upstream_nodes=upstream_nodes,
            sdk_task=self,
        )
コード例 #15
0
ファイル: reference_entity.py プロジェクト: flyteorg/flytekit
 def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
     """Build node metadata whose only explicitly set field is the name ``extract_obj_name(self.name)``."""
     node_name = extract_obj_name(self.name)
     return _workflow_model.NodeMetadata(name=node_name)
コード例 #16
0
def test_workflow_closure():
    """
    Build a WorkflowClosure around a one-node workflow and check it survives a round trip through its IDL form.

    The workflow interface has input 'a' and outputs 'b' and 'c'; each workflow output is bound to the output of
    the same name on the task node 'my_node'.
    """
    int_type = _types.LiteralType(_types.SimpleType.INTEGER)
    typed_interface = _interface.TypedInterface(
        {'a': _interface.Variable(int_type, "description1")}, {
            'b': _interface.Variable(int_type, "description2"),
            'c': _interface.Variable(int_type, "description3")
        })

    # Node input 'a' is bound to the literal integer 5.
    b0 = _literals.Binding(
        'a',
        _literals.BindingData(scalar=_literals.Scalar(
            primitive=_literals.Primitive(integer=5))))
    # Workflow outputs: 'b' <- my_node.b and 'c' <- my_node.c (one binding per declared output variable).
    b1 = _literals.Binding(
        'b',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'b')))
    b2 = _literals.Binding(
        'c',
        _literals.BindingData(promise=_types.OutputReference('my_node', 'c')))

    node_metadata = _workflow.NodeMetadata(name='node1',
                                           timeout=timedelta(seconds=10),
                                           retries=_literals.RetryStrategy(0))

    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK,
                              "1.0.0", "python"), timedelta(days=1),
        _literals.RetryStrategy(3), "0.1.1b0", "This is deprecated!")

    cpu_resource = _task.Resources.ResourceEntry(
        _task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project",
                               "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface, {
            'a': 1,
            'b': {
                'c': 2,
                'd': 3
            }
        },
        container=_task.Container("my_image", ["this", "is", "a", "cmd"],
                                  ["this", "is", "an", "arg"], resources, {},
                                  {}))

    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(id='my_node',
                          metadata=node_metadata,
                          inputs=[b0],
                          upstream_node_ids=[],
                          output_aliases=[],
                          task_node=task_node)

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project",
                                  "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2
コード例 #17
0
def create_and_link_node(
    ctx: FlyteContext,
    entity,
    interface: flyte_interface.Interface,
    timeout: Optional[datetime.timedelta] = None,
    retry_strategy: Optional[_literal_models.RetryStrategy] = None,
    **kwargs,
):
    """
    This method is used to generate a node with bindings. This is not used in the execution path.

    Every input declared by ``interface`` must be supplied in ``kwargs``. Each value is turned into a binding via
    ``binding_from_python_std``. Promise inputs that do not come from the global input node make their producing
    node an upstream dependency. The resulting Node is added to ``ctx.compilation_state``.

    :param ctx: the current context; ``ctx.compilation_state`` must be set.
    :param entity: the object the new node wraps; its ``__module__`` and ``name`` are used for naming.
    :param interface: the python interface of ``entity``.
    :param timeout: node timeout; an empty ``timedelta`` is used when not given (or falsy).
    :param retry_strategy: node retry strategy; ``RetryStrategy(0)`` is used when not given.
    :param kwargs: input values keyed by input name.
    :return: ``VoidPromise(entity.name)`` if the interface has no outputs, otherwise the result of
        ``create_task_output`` over one Promise per output.
    :raises FlyteAssertion: when not compiling, when a declared input is missing, or when extra inputs are given.
    :raises AssertionError: when an input value is a tuple.
    """
    if ctx.compilation_state is None:
        raise _user_exceptions.FlyteAssertion(
            "Cannot create node when not compiling...")

    used_inputs = set()
    bindings = []

    typed_interface = flyte_interface.transform_interface_to_typed_interface(
        interface)

    for k in sorted(interface.inputs):
        var = typed_interface.inputs[k]
        if k not in kwargs:
            raise _user_exceptions.FlyteAssertion(
                "Input was not specified for: {} of type {}".format(
                    k, var.type))
        v = kwargs[k]
        # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
        # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
        # into the function.
        if isinstance(v, tuple):
            raise AssertionError(
                f"Variable({k}) for function({entity.name}) cannot receive a multi-valued tuple {v}."
                f" Check if the predecessor function returning more than one value?"
            )
        bindings.append(
            binding_from_python_std(ctx,
                                    var_name=k,
                                    expected_literal_type=var.type,
                                    t_value=v,
                                    t_value_type=interface.inputs[k]))
        used_inputs.add(k)

    # Every name in ``used_inputs`` was read from ``kwargs``, so the leftover keys are exactly the unexpected ones.
    extra_inputs = set(kwargs) - used_inputs
    if extra_inputs:
        raise _user_exceptions.FlyteAssertion(
            "Too many inputs were specified for the interface.  Extra inputs were: {}"
            .format(extra_inputs))

    # Detect upstream nodes
    # These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes
    # ``dict.fromkeys`` de-duplicates like a set would but keeps first-seen order, so the upstream list follows
    # the order of ``kwargs`` rather than an arbitrary set-iteration order.
    upstream_nodes = list(
        dict.fromkeys(
            input_val.ref.node
            for input_val in kwargs.values()
            if isinstance(input_val, Promise)
            and input_val.ref.node_id != _common_constants.GLOBAL_INPUT_NODE_ID
        )
    )

    node_metadata = _workflow_model.NodeMetadata(
        f"{entity.__module__}.{entity.name}",
        timeout or datetime.timedelta(),
        retry_strategy or _literal_models.RetryStrategy(0),
    )

    non_sdk_node = Node(
        # TODO: Better naming, probably a derivative of the function name.
        id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
        metadata=node_metadata,
        bindings=sorted(bindings, key=lambda b: b.var),
        upstream_nodes=upstream_nodes,
        flyte_entity=entity,
    )
    ctx.compilation_state.add_node(non_sdk_node)

    if not typed_interface.outputs:
        return VoidPromise(entity.name)

    # Create a node output object for each output, they should all point to this node of course.
    # TODO: If node id gets updated later, we have to make sure to update the NodeOutput model's ID, which
    #  is currently just a static str
    node_outputs = [
        Promise(output_name, NodeOutput(node=non_sdk_node, var=output_name))
        for output_name in typed_interface.outputs
    ]

    return create_task_output(node_outputs, interface)
コード例 #18
0
def _get_sample_node_metadata():
    """Return a sample NodeMetadata named "node1" with a 10 second timeout and ``RetryStrategy(0)``."""
    fixture_timeout = timedelta(seconds=10)
    fixture_retries = _literals.RetryStrategy(0)
    return _workflow.NodeMetadata(name="node1", timeout=fixture_timeout, retries=fixture_retries)
コード例 #19
0
ファイル: workflow.py プロジェクト: fediazgon/flytekit
 def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
     """
     Build node metadata for a node that wraps this workflow.

     The name is ``<module>.<workflow name>``. ``interruptible`` comes from the workflow's metadata defaults;
     every other field is left to ``NodeMetadata``'s own defaults.
     """
     qualified_name = f"{self.__module__}.{self.name}"
     interruptible = self.workflow_metadata_defaults.interruptible
     return _workflow_model.NodeMetadata(name=qualified_name, interruptible=interruptible)
コード例 #20
0
 def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
     """Build node metadata whose only explicitly set field is the name ``<module>.<entity name>``."""
     module_path = self.__module__
     return _workflow_model.NodeMetadata(name=f"{module_path}.{self.name}")
コード例 #21
0
def test_sdk_node_from_task():
    """SdkNodes wrapping a task expose their inputs/outputs/metadata and track explicit and >>/<< dependencies."""
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    def make_node(node_id, upstream, input_value, metadata_name):
        # Build an SdkNode running ``testy_test`` whose single input "a" is bound to ``input_value``.
        binding = _literals.Binding(
            "a",
            _interface.BindingData.from_python_std(
                _types.Types.Integer.to_flyte_literal_type(), input_value),
        )
        metadata = _core_workflow_models.NodeMetadata(
            metadata_name, _datetime.timedelta(minutes=15), _literals.RetryStrategy(3))
        return _nodes.SdkNode(
            node_id,
            upstream,
            [binding],
            metadata,
            sdk_task=testy_test,
            sdk_workflow=None,
            sdk_launch_plan=None,
            sdk_branch=None,
        )

    def assert_common(node, node_id, metadata_name):
        # Checks shared by every node here: its id, one input "a", one output "b" that points back at the node,
        # metadata with the given name and 3 retries, and no output aliases.
        assert node.id == node_id
        assert len(node.inputs) == 1
        assert node.inputs[0].var == "a"
        assert len(node.outputs) == 1
        assert "b" in node.outputs
        assert node.outputs["b"].node_id == node_id
        assert node.outputs["b"].var == "b"
        assert node.outputs["b"].sdk_node == node
        assert node.outputs["b"].sdk_type == _types.Types.Integer
        assert node.metadata.name == metadata_name
        assert node.metadata.retries.retries == 3
        assert len(node.output_aliases) == 0

    n = make_node("n", [], 3, "abc")
    assert_common(n, "n", "abc")
    assert n.inputs[0].binding.scalar.primitive.integer == 3
    assert n.metadata.interruptible is None
    assert len(n.upstream_nodes) == 0
    assert len(n.upstream_node_ids) == 0

    n2 = make_node("n2", [n], n.outputs.b, "abc2")
    assert_common(n2, "n2", "abc2")
    assert n2.inputs[0].binding.promise.var == "b"
    assert n2.inputs[0].binding.promise.node_id == "n"
    assert "n" in n2.upstream_node_ids
    assert n in n2.upstream_nodes
    assert len(n2.upstream_nodes) == 1
    assert len(n2.upstream_node_ids) == 1

    # Test right shift operator and late binding
    n3 = make_node("n3", [], 3, "abc3")
    n2 >> n3
    n >> n2 >> n3
    n3 << n2
    n3 << n2 << n

    assert_common(n3, "n3", "abc3")
    assert n3.inputs[0].binding.scalar.primitive.integer == 3
    assert "n2" in n3.upstream_node_ids
    assert n2 in n3.upstream_nodes
    assert len(n3.upstream_nodes) == 1
    assert len(n3.upstream_node_ids) == 1

    # Test left shift operator and late binding
    n4 = make_node("n4", [], 3, "abc4")
    n4 << n3

    # Test that implicit dependencies don't cause direct dependencies
    n4 << n3 << n2 << n
    n >> n2 >> n3 >> n4

    assert_common(n4, "n4", "abc4")
    assert n4.inputs[0].binding.scalar.primitive.integer == 3
    assert "n3" in n4.upstream_node_ids
    assert n3 in n4.upstream_nodes
    assert len(n4.upstream_nodes) == 1
    assert len(n4.upstream_node_ids) == 1

    # Add another dependency
    n4 << n2
    assert "n3" in n4.upstream_node_ids
    assert n3 in n4.upstream_nodes
    assert "n2" in n4.upstream_node_ids
    assert n2 in n4.upstream_nodes
    assert len(n4.upstream_nodes) == 2
    assert len(n4.upstream_node_ids) == 2
コード例 #22
0
 def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
     """
     Build node metadata for a node wrapping this workflow.

     The node is named via ``extract_obj_name(self.name)`` and carries the ``interruptible`` flag from the
     workflow's metadata defaults.
     """
     node_name = extract_obj_name(self.name)
     interruptible = self.workflow_metadata_defaults.interruptible
     return _workflow_model.NodeMetadata(name=node_name, interruptible=interruptible)