def compile_into_workflow(
    self, ctx: FlyteContext, task_function: Callable, **kwargs
) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
    """Turn ``task_function`` into a dynamic workflow and serialize the result.

    Inside a ``"dynamic"`` compilation context the function is wrapped in a
    ``Workflow`` (kept on ``self._wf``), compiled with ``kwargs`` and then
    serialized via ``get_serializable``.

    :param ctx: current Flyte context; provides the compilation context and the
        ``serialization_settings`` used for serialization.
    :param task_function: the user function describing the dynamic workflow.
    :param kwargs: forwarded unchanged to the workflow's ``compile``.
    :returns: a ``LiteralMap`` built from the serialized workflow's output
        bindings when it has no nodes; otherwise a ``DynamicJobSpec`` carrying
        the nodes, the referenced tasks and sub-workflows, and the outputs.
    """
    with ctx.new_compilation_context(prefix="dynamic"):
        # TODO: Resolve circular import
        from flytekit.common.translator import get_serializable

        # Workflow-level metadata: FAIL_IMMEDIATELY failure policy, and
        # defaults with interruptible=False.
        compiled = self._wf = Workflow(
            task_function,
            metadata=WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY),
            default_metadata=WorkflowMetadataDefaults(interruptible=False),
        )
        compiled.compile(**kwargs)

        serialized = get_serializable(ctx.serialization_settings, compiled)
        node_list = serialized.nodes

        if len(node_list) == 0:
            # Nothing to schedule: hand back the strict outputs as literals directly.
            return _literal_models.LiteralMap(
                literals={
                    output.var: output.binding.to_literal_model()
                    for output in serialized._outputs
                }
            )

        # Collect the tasks and sub-workflows each node references; launch plans
        # are handled by propeller, so they are not gathered here.
        referenced_tasks, referenced_subworkflows = set(), set()
        for node in node_list:
            self.aggregate(referenced_tasks, referenced_subworkflows, node)

        return _dynamic_job.DynamicJobSpec(
            min_successes=len(node_list),
            tasks=list(referenced_tasks),
            nodes=node_list,
            outputs=serialized._outputs,
            subworkflows=list(referenced_subworkflows),
        )
def compile_into_workflow(
    self, ctx: FlyteContext, is_fast_execution: bool, task_function: Callable, **kwargs
) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
    """Compile ``task_function`` into a dynamic workflow and serialize it.

    Inside a ``"dynamic"`` compilation context the function is wrapped in a
    ``PythonFunctionWorkflow`` (stored on ``self._wf``), compiled with ``kwargs``
    and serialized with ``get_serializable``.

    :param ctx: current Flyte context. Its ``serialization_settings`` drive
        serialization. In fast-execution mode, its
        ``execution_state.additional_context`` must supply ``dynamic_addl_distro``
        and may supply ``dynamic_dest_dir`` (default ``"."``).
    :param is_fast_execution: forwarded to ``get_serializable``. When true, the
        ``{{ .remote_package_path }}`` and ``{{ .dest_dir }}`` placeholders in
        every gathered task's container args are replaced with those values.
    :param task_function: the user function that defines the dynamic workflow.
    :param kwargs: forwarded unchanged to the workflow's ``compile``.
    :returns: a ``LiteralMap`` of the workflow's output bindings when compilation
        produced no nodes; otherwise a ``DynamicJobSpec`` with the nodes, the
        referenced tasks and sub-workflows, and the outputs.
    :raises AssertionError: in fast-execution mode when the execution state has
        no additional context or no truthy ``dynamic_addl_distro`` entry.
    """
    with ctx.new_compilation_context(prefix="dynamic"):
        # TODO: Resolve circular import
        from flytekit.common.translator import get_serializable

        workflow_metadata = WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
        defaults = WorkflowMetadataDefaults(interruptible=False)

        self._wf = PythonFunctionWorkflow(task_function, metadata=workflow_metadata, default_metadata=defaults)
        self._wf.compile(**kwargs)

        # NOTE(review): a fresh OrderedDict() is passed on every call; presumably a
        # per-serialization entity cache — confirm against get_serializable.
        sdk_workflow = get_serializable(OrderedDict(), ctx.serialization_settings, self._wf, is_fast_execution)

        # If no nodes were produced, let's just return the strict outputs
        if len(sdk_workflow.nodes) == 0:
            return _literal_models.LiteralMap(
                literals={binding.var: binding.binding.to_literal_model() for binding in sdk_workflow._outputs}
            )

        # Gather underlying tasks/workflows that get referenced. Launch plans are handled by propeller.
        tasks = set()
        sub_workflows = set()
        for n in sdk_workflow.nodes:
            self.aggregate(tasks, sub_workflows, n)

        if is_fast_execution:
            additional_context = ctx.execution_state.additional_context if ctx.execution_state else None
            if not additional_context or not additional_context.get("dynamic_addl_distro"):
                raise AssertionError(
                    "Compilation for a dynamic workflow called in fast execution mode but no additional code "
                    "distribution could be retrieved"
                )
            # Logger.warn is a deprecated alias; lazy %-formatting renders the same text.
            logger.warning("ctx.execution_state.additional_context %s", additional_context)

            # Placeholder arg -> concrete value, looked up once instead of per arg.
            replacements = {
                "{{ .remote_package_path }}": additional_context.get("dynamic_addl_distro"),
                "{{ .dest_dir }}": additional_context.get("dynamic_dest_dir", "."),
            }
            sanitized_tasks = set()
            for task in tasks:
                container = task.container
                # NOTE(review): tasks without a container (presumably non-container
                # task types) have no args to rewrite; previously this raised
                # AttributeError on None.args.
                if container is not None:
                    sanitized_args = [replacements.get(arg, arg) for arg in container.args]
                    # Rewrite the existing args sequence in place (clear + extend)
                    # rather than rebinding the attribute.
                    del container.args[:]
                    container.args.extend(sanitized_args)
                sanitized_tasks.add(task)

            # Rebuild the set after the in-place mutation, in case a task's hash
            # depends on its args (not visible here — kept as a safeguard).
            tasks = sanitized_tasks

        return _dynamic_job.DynamicJobSpec(
            min_successes=len(sdk_workflow.nodes),
            tasks=list(tasks),
            nodes=sdk_workflow.nodes,
            outputs=sdk_workflow._outputs,
            subworkflows=list(sub_workflows),
        )