def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
    """
    Execute the body of a dynamic task.

    By the time this is invoked, local_execute has already unwrapped Promises and Flyte
    literal wrappers, so the kwargs here are Python native values, and this function is
    likewise expected to return Python native values.

    Because the user code inside a dynamic task constitutes a workflow, local execution
    first compiles that workflow and then runs it. In a real (platform) execution the
    task stops after compilation and emits a file describing the generated workflow
    instead of executing it.
    """
    ctx = FlyteContextManager.current_context()
    exec_state = ctx.execution_state
    mode = exec_state.mode if exec_state else None

    if mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
        # Re-enter under TASK_EXECUTION so the user's function body runs with raw
        # inputs rather than being compiled.
        task_state = exec_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
        with FlyteContextManager.with_context(ctx.with_execution_state(task_state)):
            logger.info("Executing Dynamic workflow, using raw inputs")
            return exception_scopes.user_entry_point(task_function)(**kwargs)

    if mode == ExecutionState.Mode.TASK_EXECUTION:
        # Platform-side execution: compile the dynamic body into a workflow spec.
        return self.compile_into_workflow(ctx, task_function, **kwargs)

    if mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION:
        return exception_scopes.user_entry_point(task_function)(**kwargs)

    raise ValueError(
        f"Invalid execution provided, execution state: {ctx.execution_state}"
    )
def _raw_execute(self, **kwargs) -> Any:
    """
    Run every instance of the mapped task locally and collect the results.

    This is called during locally run executions. Unlike array task execution on the
    Flyte platform, _raw_execute produces the full output collection.

    Fix: the original guard compared ``inputs.items()`` against ``None``, which can
    never be ``None`` for a dict-like interface, so the ``else None`` branch was dead
    and an empty inputs interface crashed with IndexError on ``list(...)[0]``. We now
    probe for a first input key safely and return an empty collection when there is
    nothing to map over.
    """
    # Whether the mapped task declares outputs at all; if not, results are discarded.
    outputs_expected = bool(self.interface.outputs)
    outputs = []

    # Any one input key suffices to learn the collection length; next(iter(...), None)
    # yields None (instead of raising) when the task takes no inputs.
    any_input_key = next(iter(self._run_task.interface.inputs.keys()), None)
    if any_input_key is None:
        # No inputs means no collection to iterate over — nothing to execute.
        return outputs

    for i in range(len(kwargs[any_input_key])):
        # Slice out the i-th element of every input collection for this instance.
        single_instance_inputs = {k: kwargs[k][i] for k in self.interface.inputs.keys()}
        o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs)
        if outputs_expected:
            outputs.append(o)

    return outputs
def execute(self, **kwargs):
    """
    Run the user's workflow function with the given native inputs.

    This exists only to keep the workflow call pattern symmetric with tasks: tasks
    invoke execute from dispatch_execute inside local_execute, so workflows call an
    execute inside local_execute as well, which keeps mocking clean.
    """
    wrapped_fn = exception_scopes.user_entry_point(self._workflow_function)
    return wrapped_fn(**kwargs)
def execute(self, **kwargs) -> Any:
    """
    Invoke the wrapped task function.

    If you override this method, you must also handle dynamic tasks, or the task can
    no longer be used as a dynamic task generator.
    """
    mode = self.execution_mode
    if mode == self.ExecutionBehavior.DYNAMIC:
        # Dynamic bodies go through compile-then-execute handling.
        return self.dynamic_execute(self._task_function, **kwargs)
    if mode == self.ExecutionBehavior.DEFAULT:
        return exception_scopes.user_entry_point(self._task_function)(**kwargs)
def _execute_map_task(self, ctx: FlyteContext, **kwargs) -> Any:
    """
    Run one array-job instance of the map task.

    This is called during ExecutionState.Mode.TASK_EXECUTION executions, i.e. those
    orchestrated by the Flyte platform. Each instance receives the full set of input
    collections but produces only the single output corresponding to its own index;
    the array plugin handler later assembles those individual outputs into the final
    map task output collection.
    """
    idx = self._compute_array_job_index()
    # Pick out this instance's element from every input collection.
    instance_inputs = {name: kwargs[name][idx] for name in self.interface.inputs.keys()}
    return exception_scopes.user_entry_point(self._run_task.execute)(**instance_inputs)
def compile(self, **kwargs):
    """
    Compile the workflow function into nodes and output bindings.

    Supply static Python native values in the kwargs if you want them to be used in
    the compilation. This mimics a 'closure' in the traditional sense of the word.

    Returns None when the workflow declares no outputs, the single output binding when
    it declares exactly one, and a tuple of bindings otherwise.
    """
    ctx = FlyteContextManager.current_context()
    self._input_parameters = transform_inputs_to_parameters(ctx, self.python_interface)
    all_nodes = []
    # Inherit the prefix from any enclosing compilation (e.g. a parent workflow).
    prefix = ctx.compilation_state.prefix if ctx.compilation_state is not None else ""
    with FlyteContextManager.with_context(
        ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self))
    ) as comp_ctx:
        # Construct the default input promise bindings, but then override with the provided inputs, if any
        input_kwargs = construct_input_promises([k for k in self.interface.inputs.keys()])
        input_kwargs.update(kwargs)
        # Running the user function under a compilation state records nodes instead of executing.
        workflow_outputs = exception_scopes.user_entry_point(self._workflow_function)(**input_kwargs)
        all_nodes.extend(comp_ctx.compilation_state.nodes)
        # This little loop was added as part of the task resolver change. The task resolver interface itself is
        # more or less stateless (the future-proofing get_all_tasks function notwithstanding). However the
        # implementation of the TaskResolverMixin that this workflow class inherits from (ClassStorageTaskResolver)
        # does store state. This loop adds Tasks that are defined within the body of the workflow to the workflow
        # object itself.
        for n in comp_ctx.compilation_state.nodes:
            if isinstance(n.flyte_entity, PythonAutoContainerTask) and n.flyte_entity.task_resolver == self:
                logger.debug(f"WF {self.name} saving task {n.flyte_entity.name}")
                self.add(n.flyte_entity)

    # Iterate through the workflow outputs
    bindings = []
    output_names = list(self.interface.outputs.keys())
    # The reason the length 1 case is separate is because the one output might be a list. We don't want to
    # iterate through the list here, instead we should let the binding creation unwrap it and make a binding
    # collection/map out of it.
    if len(output_names) == 1:
        if isinstance(workflow_outputs, tuple):
            if len(workflow_outputs) != 1:
                raise AssertionError(
                    f"The Workflow specification indicates only one return value, received {len(workflow_outputs)}"
                )
            if self.python_interface.output_tuple_name is None:
                raise AssertionError(
                    "Outputs specification for Workflow does not define a tuple, but return value is a tuple"
                )
            # A declared single-element named tuple: unwrap to the lone value.
            workflow_outputs = workflow_outputs[0]
        t = self.python_interface.outputs[output_names[0]]
        b = binding_from_python_std(
            ctx,
            output_names[0],
            self.interface.outputs[output_names[0]].type,
            workflow_outputs,
            t,
        )
        bindings.append(b)
    elif len(output_names) > 1:
        if not isinstance(workflow_outputs, tuple):
            raise AssertionError(
                "The Workflow specification indicates multiple return values, received only one"
            )
        if len(output_names) != len(workflow_outputs):
            raise Exception(
                f"Length mismatch {len(output_names)} vs {len(workflow_outputs)}"
            )
        for i, out in enumerate(output_names):
            # An unterminated conditional leaks a ConditionalSection into the outputs.
            if isinstance(workflow_outputs[i], ConditionalSection):
                raise AssertionError(
                    "A Conditional block (if-else) should always end with an `else_()` clause"
                )
            t = self.python_interface.outputs[out]
            b = binding_from_python_std(
                ctx,
                out,
                self.interface.outputs[out].type,
                workflow_outputs[i],
                t,
            )
            bindings.append(b)

    # Save all the things necessary to create an SdkWorkflow, except for the missing project and domain
    self._nodes = all_nodes
    self._output_bindings = bindings
    if not output_names:
        return None
    if len(output_names) == 1:
        return bindings[0]
    return tuple(bindings)