def dummy_node(node_id) -> Node:
    """Build a placeholder Node (wrapping a trivial SQLTask) keyed by *node_id* — for tests/fixtures."""
    placeholder_task = SQLTask(name="x", query_template="x", inputs={})
    node = Node(
        node_id,
        metadata=None,
        bindings=[],
        upstream_nodes=[],
        flyte_entity=placeholder_task,
    )
    # Pin the private id explicitly so it matches node_id exactly
    # (presumably the constructor may derive a different id — TODO confirm).
    node._id = node_id
    return node
def end_branch(self) -> Union[Condition, Promise]:
    """
    This should be invoked after every branch has been visited.

    Depending on the active context this returns either:
      * the selected case's output promise (local execution, last case),
      * the outputs computed from a newly created branch node (compilation, last case), or
      * the condition itself, so that further cases can be chained onto it.

    :raises AssertionError: if no workflow context is active, or if the selected
        case resolved to neither a promise nor an error.
    :raises ValueError: if the selected case carried an error instead of an output.
    """
    ctx = FlyteContext.current_context()
    if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
        """
        In case of Local workflow execution, we should first mark the branch as complete, then
        Then we first check for if this is the last case,
        In case this is the last case, we return the output from the selected case - A case should
        always be selected (see start_branch)
        If this is not the last case, we should return the condition so that further chaining can be done
        """
        # Let us mark the execution state as complete
        ctx.execution_state.branch_complete()
        if self._last_case:
            ctx.execution_state.exit_conditional_section()
            # The selected case must carry exactly one of: an output promise or an error.
            if self._selected_case.output_promise is None and self._selected_case.err is None:
                raise AssertionError("Bad conditional statements, did not resolve in a promise")
            elif self._selected_case.output_promise is not None:
                return self._selected_case.output_promise
            # No output promise, so the case must have resolved to an error.
            raise ValueError(self._selected_case.err)
        # Not the last case: hand the condition back for further chaining.
        return self._condition
    elif ctx.compilation_state:
        ########
        # COMPILATION MODE
        """
        In case this is not local workflow execution then, we should check if this is the last case.
        If so then return the promise, else return the condition
        """
        if self._last_case:
            ctx.compilation_state.exit_conditional_section()
            # branch_nodes = ctx.compilation_state.nodes
            node, promises = to_branch_node(self._name, self)
            # Verify branch_nodes == nodes in bn
            # Only promises that are not yet resolved need bindings and
            # upstream edges on the branch node.
            bindings: typing.List[Binding] = []
            upstream_nodes = set()
            for p in promises:
                if not p.is_ready:
                    bindings.append(Binding(var=p.var, binding=BindingData(promise=p.ref)))
                    upstream_nodes.add(p.ref.node)
            # Wrap the branch node in a graph Node with a deterministic id
            # derived from the compilation prefix and current node count.
            n = Node(
                id=f"{ctx.compilation_state.prefix}node-{len(ctx.compilation_state.nodes)}",
                metadata=_core_wf.NodeMetadata(self._name, timeout=datetime.timedelta(), retries=RetryStrategy(0)),
                bindings=sorted(bindings, key=lambda b: b.var),
                upstream_nodes=list(upstream_nodes),  # type: ignore
                flyte_entity=node,
            )
            ctx.compilation_state.add_node(n)
            return self._compute_outputs(n)
        # Not the last case: hand the condition back for further chaining.
        return self._condition
    raise AssertionError("Branches can only be invoked within a workflow context!")
def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidPromise]]:
    """
    This should be invoked after every branch has been visited.

    If this is not the last case, the condition is returned so the user can keep
    chaining further cases onto it. If it is the last case, the conditional
    compilation context is popped, a branch node is materialized from all of the
    collected cases, registered with the enclosing compilation state, and its
    output promise(s) are returned.
    """
    if not self._last_case:
        # More cases may follow; return the condition for further chaining.
        return self._condition

    # The conditional section is complete — pop the branch context so we are
    # back in the parent (compilation) context.
    FlyteContextManager.pop_context()
    ctx = FlyteContextManager.current_context()
    # Question: This is commented out because we don't need it? Nodes created in the conditional
    # compilation state are captured in the to_case_block? Always? Is this still true of nested
    # conditionals? Is that why propeller compiler is complaining?
    # branch_nodes = ctx.compilation_state.nodes
    branch_node, promises = to_branch_node(self._name, self)
    # Verify branch_nodes == nodes in bn
    # Only unresolved promises need bindings and upstream edges on the node.
    unresolved = [p for p in promises if not p.is_ready]
    node_bindings: typing.List[Binding] = [
        Binding(var=p.var, binding=BindingData(promise=p.ref)) for p in unresolved
    ]
    upstream = {p.ref.node for p in unresolved}
    n = Node(
        id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",  # type: ignore
        metadata=_core_wf.NodeMetadata(self._name, timeout=datetime.timedelta(), retries=RetryStrategy(0)),
        bindings=sorted(node_bindings, key=lambda b: b.var),
        upstream_nodes=list(upstream),  # type: ignore
        flyte_entity=branch_node,
    )
    FlyteContextManager.current_context().compilation_state.add_node(n)  # type: ignore
    return self._compute_outputs(n)
def create_and_link_node(
    ctx: FlyteContext,
    entity,
    interface: flyte_interface.Interface,
    timeout: Optional[datetime.timedelta] = None,
    retry_strategy: Optional[_literal_models.RetryStrategy] = None,
    **kwargs,
):
    """
    Generate a graph Node for ``entity`` with bindings computed from ``kwargs``.

    Compilation-time only — this is not used on the execution path.

    :param ctx: the current FlyteContext; must carry a compilation state.
    :param entity: the task/workflow-like object the node wraps.
    :param interface: the entity's python-native interface.
    :param timeout: optional node-level timeout (defaults to no timeout).
    :param retry_strategy: optional retry strategy (defaults to 0 retries).
    :param kwargs: user-supplied inputs, one per interface input.
    :return: a VoidPromise when the entity declares no outputs, otherwise the
        entity's output promise(s).
    :raises _user_exceptions.FlyteAssertion: when not compiling, or on
        missing/extra inputs.
    :raises AssertionError: when a tuple is passed for a single input.
    """
    if ctx.compilation_state is None:
        raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...")

    typed_interface = flyte_interface.transform_interface_to_typed_interface(interface)

    bindings = []
    used_inputs = set()
    for input_name in sorted(interface.inputs):
        var = typed_interface.inputs[input_name]
        if input_name not in kwargs:
            raise _user_exceptions.FlyteAssertion(
                "Input was not specified for: {} of type {}".format(input_name, var.type)
            )
        value = kwargs[input_name]
        # Tuples are not supported by Flyte. A tuple here usually means the
        # multiple outputs of an upstream task were passed in accidentally.
        if isinstance(value, tuple):
            raise AssertionError(
                f"Variable({input_name}) for function({entity.name}) cannot receive a multi-valued tuple {value}."
                f" Check if the predecessor function returning more than one value?"
            )
        bindings.append(
            binding_from_python_std(
                ctx,
                var_name=input_name,
                expected_literal_type=var.type,
                t_value=value,
                t_value_type=interface.inputs[input_name],
            )
        )
        used_inputs.add(input_name)

    extra_inputs = used_inputs ^ set(kwargs.keys())
    if len(extra_inputs) > 0:
        raise _user_exceptions.FlyteAssertion(
            "Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs)
        )

    # Detect upstream nodes: any Promise input points at the node that produced
    # it (the synthetic global-input node does not count as a dependency).
    # These will be our core Nodes until we can amend the Promise to use
    # NodeOutputs that reference our Nodes.
    upstream_nodes = list(
        {
            input_val.ref.node
            for input_val in kwargs.values()
            if isinstance(input_val, Promise) and input_val.ref.node_id != _common_constants.GLOBAL_INPUT_NODE_ID
        }
    )

    node_metadata = _workflow_model.NodeMetadata(
        f"{entity.__module__}.{entity.name}",
        timeout or datetime.timedelta(),
        retry_strategy or _literal_models.RetryStrategy(0),
    )

    # TODO: Better naming, probably a derivative of the function name.
    non_sdk_node = Node(
        id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
        metadata=node_metadata,
        bindings=sorted(bindings, key=lambda b: b.var),
        upstream_nodes=upstream_nodes,
        flyte_entity=entity,
    )
    ctx.compilation_state.add_node(non_sdk_node)

    if len(typed_interface.outputs) == 0:
        return VoidPromise(entity.name)

    # Create one promise per declared output, each pointing back at this node.
    # TODO: If the node id gets updated later, we have to make sure to update
    # the NodeOutput model's ID, which is currently just a static str.
    node_outputs = [
        Promise(output_name, NodeOutput(node=non_sdk_node, var=output_name))
        for output_name in typed_interface.outputs
    ]
    # Don't print this, it'll crash cuz sdk_node._upstream_node_ids might be None, but idl code will break
    return create_task_output(node_outputs, interface)
) from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference from flytekit.core.tracker import extract_task_module from flytekit.core.type_engine import TypeEngine from flytekit.exceptions import scopes as exception_scopes from flytekit.exceptions.user import FlyteValidationException, FlyteValueException from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model GLOBAL_START_NODE = Node( id=_common_constants.GLOBAL_INPUT_NODE_ID, metadata=None, bindings=[], upstream_nodes=[], flyte_entity=None, ) class WorkflowFailurePolicy(Enum): """ Defines the behavior for a workflow execution in the case of an observed node execution failure. By default, a workflow execution will immediately enter a failed state if a component node fails. """ #: Causes the entire workflow execution to fail once a component node fails. FAIL_IMMEDIATELY = _workflow_model.WorkflowMetadata.OnFailurePolicy.FAIL_IMMEDIATELY #: Will proceed to run any remaining runnable nodes once a component node fails.