示例#1
0
    def dummy_node(node_id) -> Node:
        """Build a bare ``Node`` carrying ``node_id`` and wrapping a throwaway ``SQLTask``."""
        # The entity only has to exist; its name and query are placeholders.
        placeholder_task = SQLTask(name="x", query_template="x", inputs={})
        node = Node(
            node_id,
            metadata=None,
            bindings=[],
            upstream_nodes=[],
            flyte_entity=placeholder_task,
        )

        # Also set the private id attribute directly, in addition to passing it to the constructor.
        node._id = node_id
        return node
示例#2
0
    def end_branch(self) -> Union[Condition, Promise]:
        """
        Close the branch that was just visited. This should be invoked after every branch has been visited.

        While further cases remain, the condition is returned so that chaining can continue. Once the last
        case has been visited, the result depends on the current context:

        - local workflow execution: the output promise of the selected case
        - compilation: the outputs computed for the branch node that is created and added here

        :raises AssertionError: if the selected case produced neither a promise nor an error, or if this is
            called outside of both local workflow execution and compilation
        :raises ValueError: if the selected case resolved to an error (local workflow execution only)
        """
        ctx = FlyteContext.current_context()

        ########
        # LOCAL WORKFLOW EXECUTION MODE
        exec_state = ctx.execution_state
        if exec_state and exec_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            # First mark the branch as complete, then decide what to hand back.
            exec_state.branch_complete()
            if not self._last_case:
                # Not the last case yet: return the condition so further chaining can be done.
                return self._condition

            # Last case: return the output of the selected case. A case should always be selected
            # (see start_branch).
            exec_state.exit_conditional_section()
            chosen = self._selected_case
            if chosen.output_promise is not None:
                return chosen.output_promise
            if chosen.err is None:
                raise AssertionError("Bad conditional statements, did not resolve in a promise")
            raise ValueError(chosen.err)

        ########
        # COMPILATION MODE
        comp_state = ctx.compilation_state
        if not comp_state:
            raise AssertionError("Branches can only be invoked within a workflow context!")

        if not self._last_case:
            # More cases to come: return the condition rather than a promise.
            return self._condition

        comp_state.exit_conditional_section()
        # branch_nodes = ctx.compilation_state.nodes
        branch_entity, branch_promises = to_branch_node(self._name, self)
        # Verify branch_nodes == nodes in bn

        # Only promises that are not ready yet need a binding and an upstream dependency.
        pending_bindings: typing.List[Binding] = []
        upstream = set()
        for promise in branch_promises:
            if promise.is_ready:
                continue
            pending_bindings.append(Binding(var=promise.var, binding=BindingData(promise=promise.ref)))
            upstream.add(promise.ref.node)

        branch_node = Node(
            id=f"{comp_state.prefix}node-{len(comp_state.nodes)}",
            metadata=_core_wf.NodeMetadata(self._name, timeout=datetime.timedelta(), retries=RetryStrategy(0)),
            bindings=sorted(pending_bindings, key=lambda b: b.var),
            upstream_nodes=list(upstream),  # type: ignore
            flyte_entity=branch_entity,
        )
        comp_state.add_node(branch_node)
        return self._compute_outputs(branch_node)
示例#3
0
    def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidPromise]]:
        """
        Close the branch that was just visited. This should be invoked after every branch has been visited.

        While further cases remain, the condition is returned. After the last case, the branch context is
        popped, a branch node is built and added to the compilation state of the then-current context, and
        the outputs computed for that node are returned.
        """
        if not self._last_case:
            return self._condition

        # The conditional section is complete, so pop the branch context off before looking up the context.
        FlyteContextManager.pop_context()
        ctx = FlyteContextManager.current_context()

        # Open question: the node check below is disabled because nodes created in the conditional
        #   compilation state are presumably captured in to_case_block. Is that always the case?
        #   Does it still hold for nested conditionals? Is that why the propeller compiler is complaining?
        # branch_nodes = ctx.compilation_state.nodes
        branch_entity, branch_promises = to_branch_node(self._name, self)
        # Verify branch_nodes == nodes in bn

        # Only promises that are not ready yet need a binding and an upstream dependency.
        pending_bindings: typing.List[Binding] = []
        upstream = set()
        for promise in branch_promises:
            if promise.is_ready:
                continue
            pending_bindings.append(Binding(var=promise.var, binding=BindingData(promise=promise.ref)))
            upstream.add(promise.ref.node)

        node_id = f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}"  # type: ignore
        node_metadata = _core_wf.NodeMetadata(
            self._name,
            timeout=datetime.timedelta(),
            retries=RetryStrategy(0),
        )
        branch_node = Node(
            id=node_id,
            metadata=node_metadata,
            bindings=sorted(pending_bindings, key=lambda b: b.var),
            upstream_nodes=list(upstream),  # type: ignore
            flyte_entity=branch_entity,
        )
        FlyteContextManager.current_context().compilation_state.add_node(branch_node)  # type: ignore
        return self._compute_outputs(branch_node)
示例#4
0
def create_and_link_node(
    ctx: FlyteContext,
    entity,
    interface: flyte_interface.Interface,
    timeout: Optional[datetime.timedelta] = None,
    retry_strategy: Optional[_literal_models.RetryStrategy] = None,
    **kwargs,
):
    """
    Generate a node for ``entity`` with one binding per interface input and add it to the compilation state.
    This is not used in the execution path.

    :param ctx: context whose ``compilation_state`` receives the new node
    :param entity: the entity the node wraps; its ``name`` is used in the node metadata and error messages
    :param interface: interface whose inputs must each be supplied through ``kwargs``
    :param timeout: node timeout; an empty ``timedelta`` is used when not given
    :param retry_strategy: retry strategy; ``RetryStrategy(0)`` is used when not given
    :param kwargs: the input values, keyed by input name
    :return: a ``VoidPromise`` when the interface has no outputs, otherwise the result of
        ``create_task_output`` over one ``Promise`` per output
    :raises FlyteAssertion: when not compiling, when an input is missing, or when extra inputs are given
    :raises AssertionError: when an input value is a tuple
    """
    compilation_state = ctx.compilation_state
    if compilation_state is None:
        raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...")

    used_inputs = set()
    bindings = []

    typed_interface = flyte_interface.transform_interface_to_typed_interface(interface)

    for input_name in sorted(interface.inputs):
        literal_var = typed_interface.inputs[input_name]
        if input_name not in kwargs:
            raise _user_exceptions.FlyteAssertion(
                "Input was not specified for: {} of type {}".format(input_name, literal_var.type)
            )
        value = kwargs[input_name]
        # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte.
        # Usually a tuple indicates that multiple outputs from a previous task were accidentally passed
        # into the function.
        if isinstance(value, tuple):
            raise AssertionError(
                f"Variable({input_name}) for function({entity.name}) cannot receive a multi-valued tuple {value}."
                f" Check if the predecessor function returning more than one value?"
            )
        binding = binding_from_python_std(
            ctx,
            var_name=input_name,
            expected_literal_type=literal_var.type,
            t_value=value,
            t_value_type=interface.inputs[input_name],
        )
        bindings.append(binding)
        used_inputs.add(input_name)

    # Every used input was checked to be present in kwargs above, so this symmetric difference can
    # only contain keyword arguments that the interface does not declare.
    extra_inputs = used_inputs ^ set(kwargs.keys())
    if extra_inputs:
        raise _user_exceptions.FlyteAssertion(
            "Too many inputs were specified for the interface.  Extra inputs were: {}".format(extra_inputs)
        )

    # Detect upstream nodes
    # These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes
    upstream_nodes = list(
        {
            candidate.ref.node
            for candidate in kwargs.values()
            if isinstance(candidate, Promise)
            and candidate.ref.node_id != _common_constants.GLOBAL_INPUT_NODE_ID
        }
    )

    node_metadata = _workflow_model.NodeMetadata(
        f"{entity.__module__}.{entity.name}",
        timeout or datetime.timedelta(),
        retry_strategy or _literal_models.RetryStrategy(0),
    )

    # TODO: Better naming, probably a derivative of the function name.
    node_id = f"{compilation_state.prefix}n{len(compilation_state.nodes)}"
    non_sdk_node = Node(
        id=node_id,
        metadata=node_metadata,
        bindings=sorted(bindings, key=lambda b: b.var),
        upstream_nodes=upstream_nodes,
        flyte_entity=entity,
    )
    compilation_state.add_node(non_sdk_node)

    if len(typed_interface.outputs) == 0:
        return VoidPromise(entity.name)

    # Create a node output object for each output, they should all point to this node of course.
    # TODO: If node id gets updated later, we have to make sure to update the NodeOutput model's ID, which
    #  is currently just a static str
    # Don't print these, it'll crash cuz sdk_node._upstream_node_ids might be None, but idl code will break
    node_outputs = [
        Promise(output_name, NodeOutput(node=non_sdk_node, var=output_name))
        for output_name, _ in typed_interface.outputs.items()
    ]

    return create_task_output(node_outputs, interface)
示例#5
0
)
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions import scopes as exception_scopes
from flytekit.exceptions.user import FlyteValidationException, FlyteValueException
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model

# Module-level Node built with the global input node id and with no metadata, bindings, upstream nodes
# or wrapped entity.
# NOTE(review): presumably this stands in for the workflow's own inputs as a binding source - confirm
# against its usages.
GLOBAL_START_NODE = Node(
    id=_common_constants.GLOBAL_INPUT_NODE_ID,
    metadata=None,
    bindings=[],
    upstream_nodes=[],
    flyte_entity=None,
)


class WorkflowFailurePolicy(Enum):
    """
    Defines the behavior for a workflow execution in the case of an observed node execution failure. By default, a
    workflow execution will immediately enter a failed state if a component node fails.
    """

    #: Causes the entire workflow execution to fail once a component node fails.
    FAIL_IMMEDIATELY = _workflow_model.WorkflowMetadata.OnFailurePolicy.FAIL_IMMEDIATELY

    #: Will proceed to run any remaining runnable nodes once a component node fails.