コード例 #1
0
    def compile_into_workflow(
        self, ctx: FlyteContext, is_fast_execution: bool,
        task_function: Callable, **kwargs
    ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
        """
        Compile ``task_function`` into a workflow and wrap it in a dynamic job spec.

        The function is compiled as a ``PythonFunctionWorkflow`` under a
        compilation state prefixed with "dynamic" and then serialized with
        ``get_serializable``.

        :param ctx: Context supplying the compilation state, the serialization
            settings and, in fast execution mode, the execution state.
        :param is_fast_execution: Forwarded to ``get_serializable``. When true,
            the ``{{ .remote_package_path }}`` and ``{{ .dest_dir }}``
            placeholders in every gathered task's container args are replaced
            with values from ``ctx.execution_state.additional_context``.
        :param task_function: Callable compiled as the body of the workflow.
        :param kwargs: Forwarded to ``PythonFunctionWorkflow.compile``.
        :return: A ``LiteralMap`` built from the output bindings when the
            compiled workflow has no nodes, otherwise a ``DynamicJobSpec``.
        :raises AssertionError: In fast execution mode when no
            "dynamic_addl_distro" entry can be found in the execution state's
            additional context.
        """
        if not ctx.compilation_state:
            cs = ctx.new_compilation_state("dynamic")
        else:
            cs = ctx.compilation_state.with_params(prefix="dynamic")

        with FlyteContextManager.with_context(ctx.with_compilation_state(cs)):
            # TODO: Resolve circular import
            from flytekit.common.translator import get_serializable

            workflow_metadata = WorkflowMetadata(
                on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
            defaults = WorkflowMetadataDefaults(interruptible=False)

            self._wf = PythonFunctionWorkflow(task_function,
                                              metadata=workflow_metadata,
                                              default_metadata=defaults)
            self._wf.compile(**kwargs)

            sdk_workflow = get_serializable(OrderedDict(),
                                            ctx.serialization_settings,
                                            self._wf, is_fast_execution)

            # If no nodes were produced, let's just return the strict outputs
            if len(sdk_workflow.nodes) == 0:
                return _literal_models.LiteralMap(
                    literals={
                        binding.var: binding.binding.to_literal_model()
                        for binding in sdk_workflow._outputs
                    })

            # Gather underlying tasks/workflows that get referenced. Launch plans are handled by propeller.
            tasks = set()
            sub_workflows = set()
            for n in sdk_workflow.nodes:
                self.aggregate(tasks, sub_workflows, n)

            if is_fast_execution:
                execution_state = ctx.execution_state
                addl_context = (execution_state.additional_context
                                if execution_state else None)
                if not addl_context or not addl_context.get(
                        "dynamic_addl_distro"):
                    raise AssertionError(
                        "Compilation for a dynamic workflow called in fast execution mode but no additional code "
                        "distribution could be retrieved")
                # Logger.warn is a deprecated alias; use warning() with lazy
                # %-style arguments. The emitted message is unchanged.
                logger.warning("ctx.execution_state.additional_context %s",
                               addl_context)

                # Placeholder arg -> concrete value for this execution.
                replacements = {
                    "{{ .remote_package_path }}":
                    addl_context.get("dynamic_addl_distro"),
                    "{{ .dest_dir }}":
                    addl_context.get("dynamic_dest_dir", "."),
                }

                # The tasks are mutated in place and collected into a fresh set.
                # NOTE(review): presumably a new set is needed because a task's
                # hash may depend on its args - confirm before simplifying.
                sanitized_tasks = set()
                for task in tasks:
                    sanitized_args = [
                        replacements.get(arg, arg)
                        for arg in task.container.args
                    ]
                    del task.container.args[:]
                    task.container.args.extend(sanitized_args)
                    sanitized_tasks.add(task)

                tasks = sanitized_tasks

            # Every produced node is counted towards min_successes.
            return _dynamic_job.DynamicJobSpec(
                min_successes=len(sdk_workflow.nodes),
                tasks=list(tasks),
                nodes=sdk_workflow.nodes,
                outputs=sdk_workflow._outputs,
                subworkflows=list(sub_workflows),
            )
コード例 #2
0
    def compile_into_workflow(
        self, ctx: FlyteContext, task_function: Callable, **kwargs
    ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
        """
        Compile ``task_function`` into a workflow and wrap it in a dynamic job spec.

        The function is compiled as a ``PythonFunctionWorkflow`` under a
        compilation state prefixed with "dynamic" and then serialized with
        ``get_serializable``, which also fills a mapping of every entity it
        serialized. Task templates are taken from that mapping.

        :param ctx: Context supplying the compilation state, the serialization
            settings and, when fast serialization is on, the execution state.
        :param task_function: Callable compiled as the body of the workflow.
        :param kwargs: Forwarded to ``PythonFunctionWorkflow.compile``.
        :return: A ``LiteralMap`` built from the output bindings when the
            compiled workflow has no nodes, otherwise a ``DynamicJobSpec``.
        :raises Exception: When any serialized entity is ``None``, which this
            code treats as a reference task.
        :raises AssertionError: When fast serialization is enabled but no
            "dynamic_addl_distro" entry can be found in the execution state's
            additional context.
        """
        if not ctx.compilation_state:
            cs = ctx.new_compilation_state("dynamic")
        else:
            cs = ctx.compilation_state.with_params(prefix="dynamic")

        with FlyteContextManager.with_context(ctx.with_compilation_state(cs)):
            # TODO: Resolve circular import
            from flytekit.common.translator import get_serializable

            workflow_metadata = WorkflowMetadata(
                on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
            # An unset (None) interruptible flag on the task falls back to False.
            interruptible = self.metadata.interruptible
            defaults = WorkflowMetadataDefaults(
                interruptible=interruptible
                if interruptible is not None else False)

            self._wf = PythonFunctionWorkflow(task_function,
                                              metadata=workflow_metadata,
                                              default_metadata=defaults)
            self._wf.compile(**kwargs)

            model_entities = OrderedDict()
            # See comment on reference entity checking a bit down below in this function.
            # This is the only circular dependency between the translator.py module and the rest of the flytekit
            # authoring experience.
            workflow_spec: admin_workflow_models.WorkflowSpec = get_serializable(
                model_entities, ctx.serialization_settings, self._wf)
            template = workflow_spec.template

            # If no nodes were produced, let's just return the strict outputs
            if len(template.nodes) == 0:
                return _literal_models.LiteralMap(
                    literals={
                        binding.var: binding.binding.to_literal_model()
                        for binding in template.outputs
                    })

            # This is not great. The translator.py module is relied on here (see comment above) to get the tasks and
            # subworkflow definitions. However we want to ensure that reference tasks and reference sub workflows are
            # not used.
            # TODO: Replace None with a class.
            if any(value is None for value in model_entities.values()):
                raise Exception(
                    "Reference tasks are not allowed in the dynamic - a network call is necessary "
                    "in order to retrieve the structure of the reference task."
                )

            # Gather underlying TaskTemplates that get referenced. Launch plans are handled by propeller. Subworkflows
            # should already be in the workflow spec.
            tts = [
                v.template for v in model_entities.values()
                if isinstance(v, task_models.TaskSpec)
            ]

            if ctx.serialization_settings.should_fast_serialize():
                execution_state = ctx.execution_state
                addl_context = (execution_state.additional_context
                                if execution_state else None)
                if not addl_context or not addl_context.get(
                        "dynamic_addl_distro"):
                    raise AssertionError(
                        "Compilation for a dynamic workflow called in fast execution mode but no additional code "
                        "distribution could be retrieved")
                # Logger.warn is a deprecated alias; use warning() with lazy
                # %-style arguments. The emitted message is unchanged.
                logger.warning("ctx.execution_state.additional_context %s",
                               addl_context)

                # Placeholder arg -> concrete value for this execution.
                replacements = {
                    "{{ .remote_package_path }}":
                    addl_context.get("dynamic_addl_distro"),
                    "{{ .dest_dir }}":
                    addl_context.get("dynamic_dest_dir", "."),
                }
                for task_template in tts:
                    sanitized_args = [
                        replacements.get(arg, arg)
                        for arg in task_template.container.args
                    ]
                    # Rewrite the existing args container in place.
                    del task_template.container.args[:]
                    task_template.container.args.extend(sanitized_args)

            # Every produced node is counted towards min_successes.
            return _dynamic_job.DynamicJobSpec(
                min_successes=len(template.nodes),
                tasks=tts,
                nodes=template.nodes,
                outputs=template.outputs,
                subworkflows=workflow_spec.sub_workflows,
            )
コード例 #3
0
    def compile_into_workflow(
        self, ctx: FlyteContext, task_function: Callable, **kwargs
    ) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
        """
        Build, at execution time, the workflow definition for a dynamic task.

        ``task_function`` is compiled as a ``PythonFunctionWorkflow`` and
        serialized. If the resulting workflow has no nodes, the output bindings
        are returned directly as a ``LiteralMap``; otherwise the nodes, outputs,
        sub-workflows and gathered task templates are packaged into a
        ``DynamicJobSpec``.

        :param ctx: Context supplying the compilation state and serialization
            settings.
        :param task_function: Callable compiled as the body of the workflow.
        :param kwargs: Forwarded to ``PythonFunctionWorkflow.compile``.
        :raises Exception: When a ``ReferenceTask`` is among the serialized
            entities.
        :raises TypeError: When a task's serialized form is not a ``TaskSpec``.
        """
        # TODO: circular import
        from flytekit.core.task import ReferenceTask

        compilation_state = (ctx.compilation_state.with_params(prefix="d")
                             if ctx.compilation_state else
                             ctx.new_compilation_state(prefix="d"))

        with FlyteContextManager.with_context(
                ctx.with_compilation_state(compilation_state)):
            # TODO: Resolve circular import
            from flytekit.tools.translator import get_serializable

            failure_metadata = WorkflowMetadata(
                on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
            # An unset (None) interruptible flag on the task falls back to False.
            interruptible = self.metadata.interruptible
            if interruptible is None:
                interruptible = False
            metadata_defaults = WorkflowMetadataDefaults(
                interruptible=interruptible)

            self._wf = PythonFunctionWorkflow(
                task_function,
                metadata=failure_metadata,
                default_metadata=metadata_defaults,
            )
            self._wf.compile(**kwargs)

            # get_serializable fills this mapping (entity -> serialized model)
            # as it walks the workflow. This is the only circular dependency
            # between translator.py and the rest of the authoring code.
            serialized_entities = OrderedDict()
            workflow_spec: admin_workflow_models.WorkflowSpec = get_serializable(
                serialized_entities, ctx.serialization_settings, self._wf)
            template = workflow_spec.template

            # A workflow without nodes has nothing to run: hand back its
            # outputs as plain literals.
            if len(template.nodes) == 0:
                literals = {}
                for output in template.outputs:
                    literals[output.var] = output.binding.to_literal_model()
                return _literal_models.LiteralMap(literals=literals)

            # Collect the task templates referenced by the workflow. Launch
            # plans are handled by propeller and subworkflows should already
            # be in the workflow spec, so only tasks matter here.
            task_templates = []
            for entity, model in serialized_entities.items():
                # FlyteTask: the entity is itself a TaskTemplate.
                if isinstance(entity, task_models.TaskTemplate):
                    task_templates.append(entity)
                    continue

                if not isinstance(entity, Task):
                    continue

                # Reference tasks would need a network call to flyteadmin to
                # populate the TaskTemplate model, so they are rejected.
                if isinstance(entity, ReferenceTask):
                    raise Exception(
                        "Reference tasks are currently unsupported within dynamic tasks"
                    )

                if not isinstance(model, task_models.TaskSpec):
                    raise TypeError(
                        f"Unexpected type for serialized form of task. Expected {task_models.TaskSpec}, but got {type(model)}"
                    )

                # Valid task: keep its template for the DynamicJobSpec.
                task_templates.append(model.template)

            return _dynamic_job.DynamicJobSpec(
                min_successes=len(template.nodes),
                tasks=task_templates,
                nodes=template.nodes,
                outputs=template.outputs,
                subworkflows=workflow_spec.sub_workflows,
            )