Example #1
    def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
        """
        By the time this function is invoked, the local_execute function should have unwrapped the Promises and Flyte
        literal wrappers so that the kwargs we are working with here are now Python native literal values. This
        function is also expected to return Python native literal values.

        Since the user code within a dynamic task constitutes a workflow, we have to first compile the workflow, and
        then execute that workflow.

        When running for real in production, the task would stop after the compilation step, and then create a file
        representing that newly generated workflow, instead of executing it.
        """
        ctx = FlyteContextManager.current_context()

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
            updated_exec_state = ctx.execution_state.with_params(
                mode=ExecutionState.Mode.TASK_EXECUTION)
            with FlyteContextManager.with_context(
                    ctx.with_execution_state(updated_exec_state)):
                logger.info("Executing Dynamic workflow, using raw inputs")
                return exception_scopes.user_entry_point(task_function)(
                    **kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION:
            return self.compile_into_workflow(ctx, task_function, **kwargs)

        if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION:
            return exception_scopes.user_entry_point(task_function)(**kwargs)

        raise ValueError(
            f"Invalid execution provided, execution state: {ctx.execution_state}"
        )
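
This example is essentially a three-way dispatch on the execution mode. Below is a minimal, dependency-free sketch of the same control flow; all names here (Mode, dynamic_dispatch) are illustrative stand-ins, not flytekit's API.

from enum import Enum, auto
from typing import Any, Callable

class Mode(Enum):
    LOCAL_WORKFLOW_EXECUTION = auto()
    TASK_EXECUTION = auto()
    LOCAL_TASK_EXECUTION = auto()

def dynamic_dispatch(mode: Mode, fn: Callable[..., Any], **kwargs: Any) -> Any:
    # Local runs simply invoke the user function with native Python values.
    if mode in (Mode.LOCAL_WORKFLOW_EXECUTION, Mode.LOCAL_TASK_EXECUTION):
        return fn(**kwargs)
    # On the platform, the body would be compiled into a workflow spec
    # instead of being executed (stubbed here as a placeholder dict).
    if mode == Mode.TASK_EXECUTION:
        return {"compiled_workflow_for": fn.__name__}
    raise ValueError(f"Invalid execution mode: {mode}")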
Example #2
    def _raw_execute(self, **kwargs) -> Any:
        """
        This is called during locally run executions. Unlike array task execution on the Flyte platform, _raw_execute
        produces the full output collection.
        """
        outputs_expected = bool(self.interface.outputs)
        outputs = []

        # Note: dict.items() is never None, so guard on the inputs mapping itself
        any_input_key = (
            list(self._run_task.interface.inputs.keys())[0]
            if self._run_task.interface.inputs
            else None
        )

        for i in range(len(kwargs[any_input_key])):
            single_instance_inputs = {}
            for k in self.interface.inputs.keys():
                single_instance_inputs[k] = kwargs[k][i]
            o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs)
            if outputs_expected:
                outputs.append(o)

        return outputs
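
A dependency-free sketch of the same fan-out, assuming every input is a list of equal length (toy names, not flytekit's API): the task runs once per index and the per-index results are collected into the output list.

def raw_execute_sketch(run, **kwargs):
    # Pick any input key to learn the fan-out length, then run the wrapped
    # task once per index, slicing out the i-th element of every input.
    any_key = next(iter(kwargs), None)
    n = len(kwargs[any_key]) if any_key is not None else 0
    return [run(**{k: v[i] for k, v in kwargs.items()}) for i in range(n)]

# e.g. raw_execute_sketch(lambda x, y: x + y, x=[1, 2], y=[10, 20]) == [11, 22]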
Example #3
    def execute(self, **kwargs):
        """
        This function exists only to streamline the pattern between workflows and tasks. Since tasks call execute
        from dispatch_execute, which is in turn called from local_execute, workflows should also call execute inside
        local_execute. This makes mocking cleaner.
        """
        return exception_scopes.user_entry_point(
            self._workflow_function)(**kwargs)
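
To see why funneling everything through execute makes mocking cleaner, here is a toy sketch (class and method names are illustrative): patching the single execute seam stubs out the whole entity without touching the wrapping logic that local_execute would perform.

from unittest import mock

class WorkflowSketch:
    def _workflow_function(self, **kwargs):
        return sum(kwargs.values())

    def execute(self, **kwargs):
        return self._workflow_function(**kwargs)

    def local_execute(self, **kwargs):
        # Literal wrapping/unwrapping would happen around this call.
        return self.execute(**kwargs)

wf = WorkflowSketch()
with mock.patch.object(WorkflowSketch, "execute", return_value=99):
    assert wf.local_execute(a=1, b=2) == 99  # user code never runs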
Example #4
    def execute(self, **kwargs) -> Any:
        """
        This method will be invoked to execute the task. If you do decide to override this method, you must also
        handle dynamic tasks, or you will no longer be able to use the task as a dynamic task generator.
        """
        if self.execution_mode == self.ExecutionBehavior.DEFAULT:
            return exception_scopes.user_entry_point(
                self._task_function)(**kwargs)
        elif self.execution_mode == self.ExecutionBehavior.DYNAMIC:
            return self.dynamic_execute(self._task_function, **kwargs)
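
The docstring's warning can be honored by delegating to super() so the DYNAMIC branch still runs. A toy sketch with stand-in classes (not flytekit's):

import time

class TaskSketch:
    """Toy stand-in for the dispatching execute() above."""
    def __init__(self, fn, dynamic=False):
        self._fn, self._dynamic = fn, dynamic

    def execute(self, **kwargs):
        if self._dynamic:
            return f"compiled({self._fn.__name__})"  # stands in for dynamic_execute
        return self._fn(**kwargs)

class TimedTask(TaskSketch):
    def execute(self, **kwargs):
        # Delegating to super() keeps the dynamic-task path intact.
        t0 = time.monotonic()
        try:
            return super().execute(**kwargs)
        finally:
            print(f"execute took {time.monotonic() - t0:.4f}s")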
Example #5
    def _execute_map_task(self, ctx: FlyteContext, **kwargs) -> Any:
        """
        This is called during ExecutionState.Mode.TASK_EXECUTION executions, that is, executions orchestrated by the
        Flyte platform. Individual instances of the map task (aka array task jobs) are passed the full set of inputs,
        but each produces a single output based on its map task (array task) index. The array plugin handler then
        assembles these individual outputs into a collection as the final map task output value.
        """
        task_index = self._compute_array_job_index()
        map_task_inputs = {}
        for k in self.interface.inputs.keys():
            map_task_inputs[k] = kwargs[k][task_index]
        return exception_scopes.user_entry_point(self._run_task.execute)(**map_task_inputs)
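
A sketch of the per-instance selection, assuming the job index arrives via an environment variable (the variable name below is illustrative; the real lookup is plugin-specific and lives in _compute_array_job_index):

import os

def execute_map_instance(run, **kwargs):
    # Every array job receives the full input lists but only processes the
    # element at its own index; the plugin later gathers the outputs.
    idx = int(os.environ.get("ARRAY_JOB_INDEX", "0"))  # illustrative variable name
    return run(**{k: v[idx] for k, v in kwargs.items()})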
Example #6
    def compile(self, **kwargs):
        """
        Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics
        a 'closure' in the traditional sense of the word.
        """
        ctx = FlyteContextManager.current_context()
        self._input_parameters = transform_inputs_to_parameters(
            ctx, self.python_interface)
        all_nodes = []
        prefix = ctx.compilation_state.prefix if ctx.compilation_state is not None else ""

        with FlyteContextManager.with_context(
                ctx.with_compilation_state(
                    CompilationState(prefix=prefix,
                                     task_resolver=self))) as comp_ctx:
            # Construct the default input promise bindings, but then override with the provided inputs, if any
            input_kwargs = construct_input_promises(
                [k for k in self.interface.inputs.keys()])
            input_kwargs.update(kwargs)
            workflow_outputs = exception_scopes.user_entry_point(
                self._workflow_function)(**input_kwargs)
            all_nodes.extend(comp_ctx.compilation_state.nodes)

            # This little loop was added as part of the task resolver change. The task resolver interface itself is
            # more or less stateless (the future-proofing get_all_tasks function notwithstanding). However, the
            # implementation of the TaskResolverMixin that this workflow class inherits from (ClassStorageTaskResolver)
            # does store state. This loop adds Tasks that are defined within the body of the workflow to the workflow
            # object itself.
            for n in comp_ctx.compilation_state.nodes:
                if isinstance(n.flyte_entity, PythonAutoContainerTask
                              ) and n.flyte_entity.task_resolver == self:
                    logger.debug(
                        f"WF {self.name} saving task {n.flyte_entity.name}")
                    self.add(n.flyte_entity)

        # Iterate through the workflow outputs
        bindings = []
        output_names = list(self.interface.outputs.keys())
        # The length-1 case is handled separately because the single output might be a list. We don't want to
        # iterate through that list here; instead we let the binding creation unwrap it and build a binding
        # collection/map out of it.
        if len(output_names) == 1:
            if isinstance(workflow_outputs, tuple):
                if len(workflow_outputs) != 1:
                    raise AssertionError(
                        f"The Workflow specification indicates only one return value, received {len(workflow_outputs)}"
                    )
                if self.python_interface.output_tuple_name is None:
                    raise AssertionError(
                        "Outputs specification for Workflow does not define a tuple, but return value is a tuple"
                    )
                workflow_outputs = workflow_outputs[0]
            t = self.python_interface.outputs[output_names[0]]
            b = binding_from_python_std(
                ctx,
                output_names[0],
                self.interface.outputs[output_names[0]].type,
                workflow_outputs,
                t,
            )
            bindings.append(b)
        elif len(output_names) > 1:
            if not isinstance(workflow_outputs, tuple):
                raise AssertionError(
                    "The Workflow specification indicates multiple return values, received only one"
                )
            if len(output_names) != len(workflow_outputs):
                raise Exception(
                    f"Length mismatch {len(output_names)} vs {len(workflow_outputs)}"
                )
            for i, out in enumerate(output_names):
                if isinstance(workflow_outputs[i], ConditionalSection):
                    raise AssertionError(
                        "A Conditional block (if-else) should always end with an `else_()` clause"
                    )
                t = self.python_interface.outputs[out]
                b = binding_from_python_std(
                    ctx,
                    out,
                    self.interface.outputs[out].type,
                    workflow_outputs[i],
                    t,
                )
                bindings.append(b)

        # Save all the things necessary to create an SdkWorkflow, except for the missing project and domain
        self._nodes = all_nodes
        self._output_bindings = bindings
        if not output_names:
            return None
        if len(output_names) == 1:
            return bindings[0]
        return tuple(bindings)
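
The return shape at the end mirrors the interface: zero outputs yields None, one output yields the single binding, and several yield a tuple. A toy illustration with bindings stubbed as strings:

def shape_bindings(bindings):
    if not bindings:
        return None
    if len(bindings) == 1:
        return bindings[0]
    return tuple(bindings)

assert shape_bindings([]) is None
assert shape_bindings(["b0"]) == "b0"
assert shape_bindings(["b0", "b1"]) == ("b0", "b1")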