def get_serializable_references(
    entity_mapping: OrderedDict,
    settings: SerializationSettings,
    entity: FlyteLocalEntity,
    fast: bool,
) -> FlyteControlPlaneEntity:
    """Produce a hollow control-plane stand-in for a reference entity.

    Reference entities already exist on the Admin control plane, so the objects
    built here are deliberately fake: they carry only the id and interface that
    downstream serialization code needs, and are flagged as already registered
    so they never get re-registered.

    :param entity_mapping: accumulator of serialized entities — unused here;
        presumably kept for signature parity with sibling functions (TODO confirm)
    :param settings: serialization settings — unused in this function
    :param entity: must be a ReferenceTask, ReferenceWorkflow, or ReferenceLaunchPlan
    :param fast: fast-registration flag — unused in this function
    :raises Exception: if ``entity`` is not one of the three reference types
    """
    # TODO: This entire function isn't necessary. We should just return None or raise an Exception or something.
    # Reference entities should already exist on the Admin control plane - they should not be serialized/registered
    # again. Indeed we don't actually have enough information to serialize it properly.
    if isinstance(entity, ReferenceTask):
        # Placeholder task: type/metadata/custom/container are all dummy values.
        cp_entity = SdkTask(
            type="ignore",
            metadata=TaskMetadata().to_taskmetadata_model(),
            interface=entity.interface,
            custom={},
            container=None,
        )
    elif isinstance(entity, ReferenceWorkflow):
        workflow_metadata = WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
        cp_entity = SdkWorkflow(
            nodes=[],  # Fake an empty list for nodes,
            id=entity.reference.id,
            metadata=workflow_metadata,
            metadata_defaults=workflow_model.WorkflowMetadataDefaults(),
            interface=entity.interface,
            output_bindings=[],
        )
    elif isinstance(entity, ReferenceLaunchPlan):
        # Placeholder launch plan: every field is an empty/fake value.
        cp_entity = SdkLaunchPlan(
            workflow_id=None,
            entity_metadata=_launch_plan_models.LaunchPlanMetadata(schedule=None, notifications=[]),
            default_inputs=interface_models.ParameterMap({}),
            fixed_inputs=literal_models.LiteralMap({}),
            labels=_common_models.Labels({}),
            annotations=_common_models.Annotations({}),
            auth_role=_common_models.AuthRole(assumable_iam_role="fake:role"),
            raw_output_data_config=RawOutputDataConfig(""),
        )
        # Because of how SdkNodes work, it needs one of these interfaces
        # Hopefully this is more trickery that can be cleaned up in the future
        cp_entity._interface = TypedInterface.promote_from_model(entity.interface)
    else:
        raise Exception("Invalid reference type when serializing")

    # Make sure we don't serialize this
    cp_entity._has_registered = True
    # Stamp the real id from the reference onto the placeholder.
    cp_entity.assign_name(entity.id.name)
    cp_entity._id = entity.id
    return cp_entity
def compile_into_workflow(
    self, ctx: FlyteContext, task_function: Callable, **kwargs
) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
    """Compile ``task_function`` into a workflow at execution time.

    Returns a LiteralMap directly when compilation yields no nodes; otherwise
    returns a DynamicJobSpec describing the nodes, tasks and sub-workflows.
    """
    with ctx.new_compilation_context(prefix="dynamic"):
        # TODO: Resolve circular import
        from flytekit.common.translator import get_serializable

        wf_metadata = WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
        wf_defaults = WorkflowMetadataDefaults(interruptible=False)
        self._wf = Workflow(task_function, metadata=wf_metadata, default_metadata=wf_defaults)
        self._wf.compile(**kwargs)

        wf = self._wf
        compiled_wf = get_serializable(ctx.serialization_settings, wf)

        # If no nodes were produced, let's just return the strict outputs
        if not compiled_wf.nodes:
            out_literals = {b.var: b.binding.to_literal_model() for b in compiled_wf._outputs}
            return _literal_models.LiteralMap(literals=out_literals)

        # Gather underlying tasks/workflows that get referenced. Launch plans are handled by propeller.
        task_set = set()
        subwf_set = set()
        for node in compiled_wf.nodes:
            self.aggregate(task_set, subwf_set, node)

        return _dynamic_job.DynamicJobSpec(
            min_successes=len(compiled_wf.nodes),
            tasks=list(task_set),
            nodes=compiled_wf.nodes,
            outputs=compiled_wf._outputs,
            subworkflows=list(subwf_set),
        )
def compile_into_workflow(
    self, ctx: FlyteContext, task_function: Callable, **kwargs
) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
    """Compile ``task_function`` into a workflow definition at execution time.

    :param ctx: active FlyteContext; its compilation state is reused (with a
        "dynamic" prefix) when present, otherwise a fresh one is created
    :param task_function: the python function backing the dynamic workflow
    :return: a LiteralMap of the strict outputs when compilation produced no
        nodes, otherwise a DynamicJobSpec referencing nodes, task templates and
        sub-workflows
    :raises Exception: if a reference task is found among the serialized entities
    :raises AssertionError: in fast-serialization mode when no additional code
        distribution is available in the execution state
    """
    if not ctx.compilation_state:
        cs = ctx.new_compilation_state("dynamic")
    else:
        cs = ctx.compilation_state.with_params(prefix="dynamic")

    with FlyteContextManager.with_context(ctx.with_compilation_state(cs)):
        # TODO: Resolve circular import
        from flytekit.common.translator import get_serializable

        workflow_metadata = WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
        defaults = WorkflowMetadataDefaults(
            interruptible=self.metadata.interruptible if self.metadata.interruptible is not None else False
        )
        self._wf = PythonFunctionWorkflow(task_function, metadata=workflow_metadata, default_metadata=defaults)
        self._wf.compile(**kwargs)
        wf = self._wf
        model_entities = OrderedDict()
        # See comment on reference entity checking a bit down below in this function.
        # This is the only circular dependency between the translator.py module and the rest of the flytekit
        # authoring experience.
        workflow_spec: admin_workflow_models.WorkflowSpec = get_serializable(
            model_entities, ctx.serialization_settings, wf
        )

        # If no nodes were produced, let's just return the strict outputs
        if len(workflow_spec.template.nodes) == 0:
            return _literal_models.LiteralMap(
                literals={
                    binding.var: binding.binding.to_literal_model() for binding in workflow_spec.template.outputs
                }
            )

        # This is not great. The translator.py module is relied on here (see comment above) to get the tasks and
        # subworkflow definitions. However we want to ensure that reference tasks and reference sub workflows are
        # not used.
        # TODO: Replace None with a class.
        for value in model_entities.values():
            if value is None:
                raise Exception(
                    "Reference tasks are not allowed in the dynamic - a network call is necessary "
                    "in order to retrieve the structure of the reference task."
                )

        # Gather underlying TaskTemplates that get referenced. Launch plans are handled by propeller. Subworkflows
        # should already be in the workflow spec.
        tts = [v.template for v in model_entities.values() if isinstance(v, task_models.TaskSpec)]

        if ctx.serialization_settings.should_fast_serialize():
            if (
                not ctx.execution_state
                or not ctx.execution_state.additional_context
                or not ctx.execution_state.additional_context.get("dynamic_addl_distro")
            ):
                raise AssertionError(
                    "Compilation for a dynamic workflow called in fast execution mode but no additional code "
                    "distribution could be retrieved"
                )
            # Hoist the dict lookup; it is reused for every task template below.
            additional_context = ctx.execution_state.additional_context
            # FIX: logger.warn is a deprecated alias of logger.warning.
            logger.warning(f"ctx.execution_state.additional_context {additional_context}")
            for task_template in tts:
                # Rewrite fast-registration placeholder args with the values supplied at execution time.
                sanitized_args = []
                for arg in task_template.container.args:
                    if arg == "{{ .remote_package_path }}":
                        sanitized_args.append(additional_context.get("dynamic_addl_distro"))
                    elif arg == "{{ .dest_dir }}":
                        sanitized_args.append(additional_context.get("dynamic_dest_dir", "."))
                    else:
                        sanitized_args.append(arg)
                # Replace the args in place so the same container object is kept.
                del task_template.container.args[:]
                task_template.container.args.extend(sanitized_args)

        dj_spec = _dynamic_job.DynamicJobSpec(
            min_successes=len(workflow_spec.template.nodes),
            tasks=tts,
            nodes=workflow_spec.template.nodes,
            outputs=workflow_spec.template.outputs,
            subworkflows=workflow_spec.sub_workflows,
        )
        return dj_spec
def test_metadata_values():
    """WorkflowMetadata rejects raw ints and stores a valid failure policy as-is."""
    with pytest.raises(FlyteValidationException):
        WorkflowMetadata(on_failure=0)

    policy = WorkflowFailurePolicy.FAIL_IMMEDIATELY
    metadata = WorkflowMetadata(on_failure=policy)
    assert metadata.on_failure == policy
def compile_into_workflow(
    self, ctx: FlyteContext, is_fast_execution: bool, task_function: Callable, **kwargs
) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
    """Compile ``task_function`` into a workflow definition at execution time.

    :param ctx: active FlyteContext; its compilation state is reused (with a
        "dynamic" prefix) when present, otherwise a fresh one is created
    :param is_fast_execution: when True, placeholder container args in the
        gathered tasks are rewritten from the execution state's additional context
    :param task_function: the python function backing the dynamic workflow
    :return: a LiteralMap of the strict outputs when compilation produced no
        nodes, otherwise a DynamicJobSpec referencing nodes, tasks and sub-workflows
    :raises AssertionError: in fast execution mode when no additional code
        distribution is available in the execution state
    """
    if not ctx.compilation_state:
        cs = ctx.new_compilation_state("dynamic")
    else:
        cs = ctx.compilation_state.with_params(prefix="dynamic")

    with FlyteContextManager.with_context(ctx.with_compilation_state(cs)):
        # TODO: Resolve circular import
        from flytekit.common.translator import get_serializable

        workflow_metadata = WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
        defaults = WorkflowMetadataDefaults(interruptible=False)
        self._wf = PythonFunctionWorkflow(task_function, metadata=workflow_metadata, default_metadata=defaults)
        self._wf.compile(**kwargs)
        wf = self._wf
        sdk_workflow = get_serializable(OrderedDict(), ctx.serialization_settings, wf, is_fast_execution)

        # If no nodes were produced, let's just return the strict outputs
        if len(sdk_workflow.nodes) == 0:
            return _literal_models.LiteralMap(
                literals={binding.var: binding.binding.to_literal_model() for binding in sdk_workflow._outputs}
            )

        # Gather underlying tasks/workflows that get referenced. Launch plans are handled by propeller.
        tasks = set()
        sub_workflows = set()
        for n in sdk_workflow.nodes:
            self.aggregate(tasks, sub_workflows, n)

        if is_fast_execution:
            if (
                not ctx.execution_state
                or not ctx.execution_state.additional_context
                or not ctx.execution_state.additional_context.get("dynamic_addl_distro")
            ):
                raise AssertionError(
                    "Compilation for a dynamic workflow called in fast execution mode but no additional code "
                    "distribution could be retrieved"
                )
            # Hoist the dict lookup; it is reused for every task below.
            additional_context = ctx.execution_state.additional_context
            # FIX: logger.warn is a deprecated alias of logger.warning.
            logger.warning(f"ctx.execution_state.additional_context {additional_context}")
            # Rebuild the set rather than mutating members in place only, so entries
            # are re-inserted after their args (which may affect hashing) change.
            sanitized_tasks = set()
            for task in tasks:
                # Rewrite fast-registration placeholder args with execution-time values.
                sanitized_args = []
                for arg in task.container.args:
                    if arg == "{{ .remote_package_path }}":
                        sanitized_args.append(additional_context.get("dynamic_addl_distro"))
                    elif arg == "{{ .dest_dir }}":
                        sanitized_args.append(additional_context.get("dynamic_dest_dir", "."))
                    else:
                        sanitized_args.append(arg)
                # Replace the args in place so the same container object is kept.
                del task.container.args[:]
                task.container.args.extend(sanitized_args)
                sanitized_tasks.add(task)
            tasks = sanitized_tasks

        dj_spec = _dynamic_job.DynamicJobSpec(
            min_successes=len(sdk_workflow.nodes),
            tasks=list(tasks),
            nodes=sdk_workflow.nodes,
            outputs=sdk_workflow._outputs,
            subworkflows=list(sub_workflows),
        )
        return dj_spec
def compile_into_workflow(
    self, ctx: FlyteContext, task_function: Callable, **kwargs
) -> Union[_dynamic_job.DynamicJobSpec, _literal_models.LiteralMap]:
    """
    In the case of dynamic workflows, this function will produce a workflow definition at execution time which will
    then proceed to be executed.

    :param ctx: active FlyteContext; its compilation state is reused (with prefix
        "d") when present, otherwise a fresh one is created
    :param task_function: the python function backing the dynamic workflow
    :return: a LiteralMap of the strict outputs when compilation produced no
        nodes, otherwise a DynamicJobSpec referencing nodes, task templates and
        sub-workflows
    :raises Exception: if a ReferenceTask is found among the serialized entities
    :raises TypeError: if a Task's serialized form is not a TaskSpec
    """
    # TODO: circular import
    from flytekit.core.task import ReferenceTask

    if not ctx.compilation_state:
        cs = ctx.new_compilation_state(prefix="d")
    else:
        cs = ctx.compilation_state.with_params(prefix="d")

    with FlyteContextManager.with_context(ctx.with_compilation_state(cs)):
        # TODO: Resolve circular import
        from flytekit.tools.translator import get_serializable

        workflow_metadata = WorkflowMetadata(on_failure=WorkflowFailurePolicy.FAIL_IMMEDIATELY)
        defaults = WorkflowMetadataDefaults(
            interruptible=self.metadata.interruptible if self.metadata.interruptible is not None else False
        )
        self._wf = PythonFunctionWorkflow(task_function, metadata=workflow_metadata, default_metadata=defaults)
        self._wf.compile(**kwargs)
        wf = self._wf
        # get_serializable fills model_entities with every entity it serializes
        # along the way (entity -> serialized model).
        model_entities = OrderedDict()
        # See comment on reference entity checking a bit down below in this function.
        # This is the only circular dependency between the translator.py module and the rest of the flytekit
        # authoring experience.
        workflow_spec: admin_workflow_models.WorkflowSpec = get_serializable(
            model_entities, ctx.serialization_settings, wf
        )

        # If no nodes were produced, let's just return the strict outputs
        if len(workflow_spec.template.nodes) == 0:
            return _literal_models.LiteralMap(
                literals={
                    binding.var: binding.binding.to_literal_model() for binding in workflow_spec.template.outputs
                }
            )

        # Gather underlying TaskTemplates that get referenced.
        tts = []
        for entity, model in model_entities.items():
            # We only care about gathering tasks here. Launch plans are handled by
            # propeller. Subworkflows should already be in the workflow spec.
            if not isinstance(entity, Task) and not isinstance(entity, task_models.TaskTemplate):
                continue

            # Handle FlyteTask
            # NOTE: this check must run before the ReferenceTask check below —
            # a raw TaskTemplate is appended as-is.
            if isinstance(entity, task_models.TaskTemplate):
                tts.append(entity)
                continue

            # We are currently not supporting reference tasks since these will
            # require a network call to flyteadmin to populate the TaskTemplate
            # model
            if isinstance(entity, ReferenceTask):
                raise Exception("Reference tasks are currently unsupported within dynamic tasks")

            if not isinstance(model, task_models.TaskSpec):
                raise TypeError(
                    f"Unexpected type for serialized form of task. Expected {task_models.TaskSpec}, but got {type(model)}"
                )

            # Store the valid task template so that we can pass it to the
            # DynamicJobSpec later
            tts.append(model.template)

        dj_spec = _dynamic_job.DynamicJobSpec(
            min_successes=len(workflow_spec.template.nodes),
            tasks=tts,
            nodes=workflow_spec.template.nodes,
            outputs=workflow_spec.template.outputs,
            subworkflows=workflow_spec.sub_workflows,
        )
        return dj_spec