def workflow_state_from_storage(
    workflow_id: str, task_id: Optional[TaskID]
) -> WorkflowExecutionState:
    """Try to construct a workflow execution state that recovers the workflow.

    If a task already has a checkpointed output, the checkpoint is used
    directly instead of re-executing the task.

    Args:
        workflow_id: The ID of the workflow.
        task_id: The ID of the output task. If None, it will be the
            entrypoint of the workflow.

    Returns:
        A workflow execution state that can be used to resume the workflow.
    """
    reader = workflow_storage.WorkflowStorage(workflow_id)
    if task_id is None:
        task_id = reader.get_entrypoint_step_id()

    # Construct the workflow execution state.
    state = WorkflowExecutionState(output_task_id=task_id)

    visited_tasks = set()
    dag_visit_queue = deque([task_id])
    with serialization.objectref_cache():
        while dag_visit_queue:
            task_id: TaskID = dag_visit_queue.popleft()
            if task_id in visited_tasks:
                continue
            visited_tasks.add(task_id)
            r = reader.inspect_step(task_id)
            if not r.is_recoverable():
                raise WorkflowStepNotRecoverableError(task_id)
            if r.output_object_valid:
                target = state.continuation_root.get(task_id, task_id)
                state.checkpoint_map[target] = WorkflowRef(task_id)
                continue
            if isinstance(r.output_step_id, str):
                # No input dependencies here, because the task has already
                # returned a continuation.
                state.upstream_dependencies[task_id] = []
                state.append_continuation(task_id, r.output_step_id)
                dag_visit_queue.append(r.output_step_id)
                continue
            # Transfer task info to the state.
            state.add_dependencies(task_id, r.workflow_refs)
            state.task_input_args[task_id] = reader.load_step_args(task_id)
            # TODO(suquark): Although not necessary, for completeness we may
            #  also load the task name and metadata.
            state.tasks[task_id] = Task(
                name="",
                options=r.step_options,
                user_metadata={},
                func_body=reader.load_step_func_body(task_id),
            )
            dag_visit_queue.extend(r.workflow_refs)

    return state
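
# A minimal, self-contained sketch of the BFS above, assuming a toy `toy_dag`
# mapping in place of `reader.inspect_step`. The names `ToyInspectResult`,
# `toy_dag`, and `visit` are hypothetical and not part of the workflow storage
# API; the sketch only shows the traversal pattern: stop at tasks whose output
# is checkpointed, follow continuations, and enqueue upstream dependencies.
from collections import deque
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyInspectResult:
    output_object_valid: bool = False
    output_step_id: Optional[str] = None
    workflow_refs: List[str] = field(default_factory=list)


# "C" is checkpointed, "B" continues into "C", "A" depends on "B".
toy_dag = {
    "A": ToyInspectResult(workflow_refs=["B"]),
    "B": ToyInspectResult(output_step_id="C"),
    "C": ToyInspectResult(output_object_valid=True),
}


def visit(output_task: str) -> List[str]:
    """Return the tasks that must be rebuilt, in BFS order."""
    to_rebuild, visited, queue = [], set(), deque([output_task])
    while queue:
        task = queue.popleft()
        if task in visited:
            continue
        visited.add(task)
        r = toy_dag[task]
        if r.output_object_valid:
            continue  # Checkpointed: nothing to rebuild.
        if r.output_step_id is not None:
            queue.append(r.output_step_id)  # Follow the continuation.
        to_rebuild.append(task)
        queue.extend(r.workflow_refs)
    return to_rebuild


assert visit("A") == ["A", "B"]  # "C" is skipped: its output is checkpointed.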
def _construct_resume_workflow_from_step(
    reader: workflow_storage.WorkflowStorage,
    step_id: StepID,
    input_map: Dict[StepID, Any],
) -> Union[Workflow, StepID]:
    """Try to construct a workflow (step) that recovers the workflow step.

    If the workflow step already has an output checkpointing file, we return
    the workflow step id instead.

    Args:
        reader: The storage reader for inspecting the step.
        step_id: The ID of the step we want to recover.
        input_map: A context storing the inputs that have already been
            loaded. This context is important for deduplicating step inputs.

    Returns:
        A workflow that recovers the step, or the ID of a step that contains
            the output checkpoint file.
    """
    result: workflow_storage.StepInspectResult = reader.inspect_step(step_id)
    if result.output_object_valid:
        # We already have the output.
        return step_id
    if isinstance(result.output_step_id, str):
        return _construct_resume_workflow_from_step(
            reader, result.output_step_id, input_map
        )
    # The output does not exist or is not valid. Try to reconstruct it.
    if not result.is_recoverable():
        raise WorkflowStepNotRecoverableError(step_id)

    with serialization.objectref_cache():
        input_workflows = []
        for _step_id in result.workflows:
            # Check whether the step has been loaded or not to avoid
            # duplication.
            if _step_id in input_map:
                r = input_map[_step_id]
            else:
                r = _construct_resume_workflow_from_step(reader, _step_id, input_map)
                input_map[_step_id] = r
            if isinstance(r, Workflow):
                input_workflows.append(r)
            else:
                assert isinstance(r, StepID)
                # TODO(Alex): We should consider caching these outputs too.
                input_workflows.append(reader.load_step_output(r))
        workflow_refs = list(map(WorkflowRef, result.workflow_refs))

        args, kwargs = reader.load_step_args(step_id, input_workflows, workflow_refs)
        recovery_workflow: Workflow = _recover_workflow_step.options(
            max_retries=result.max_retries,
            catch_exceptions=result.catch_exceptions,
            **result.ray_options,
        ).step(args, kwargs, input_workflows, workflow_refs)
        recovery_workflow._step_id = step_id
        recovery_workflow.data.step_type = result.step_type
        return recovery_workflow
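
# A minimal sketch of the `input_map` dedup pattern above: a recursive walk
# over a toy dependency graph that memoizes results so a shared upstream step
# is reconstructed exactly once. `toy_deps`, `calls`, and `reconstruct` are
# hypothetical stand-ins, not part of the workflow API.
from typing import Any, Dict, List

toy_deps: Dict[str, List[str]] = {
    "out": ["left", "right"],
    "left": ["shared"],
    "right": ["shared"],
    "shared": [],
}
calls: List[str] = []


def reconstruct(step: str, input_map: Dict[str, Any]) -> str:
    calls.append(step)
    inputs = []
    for dep in toy_deps[step]:
        if dep not in input_map:  # Only recurse if not loaded yet.
            input_map[dep] = reconstruct(dep, input_map)
        inputs.append(input_map[dep])
    return f"{step}({', '.join(inputs)})"


result = reconstruct("out", {})
assert result == "out(left(shared()), right(shared()))"
assert calls.count("shared") == 1  # Deduped: reconstructed only once.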
def _construct_resume_workflow_from_step(
    workflow_id: str, step_id: StepID
) -> Union[Workflow, Any]:
    """Try to construct a workflow (step) that recovers the workflow step.

    If the workflow step already has an output checkpointing file, we return
    the workflow step id instead.

    Args:
        workflow_id: The ID of the workflow.
        step_id: The ID of the step we want to recover.

    Returns:
        A workflow that recovers the step, or the output of the step
            if it has been checkpointed.
    """
    reader = workflow_storage.WorkflowStorage(workflow_id)

    # Step 1: construct the dependency DAG (BFS).
    inspect_results = {}
    dependency_map = defaultdict(list)
    num_in_edges = {}

    dag_visit_queue = deque([step_id])
    while dag_visit_queue:
        s: StepID = dag_visit_queue.popleft()
        if s in inspect_results:
            continue
        r = reader.inspect_step(s)
        inspect_results[s] = r
        if not r.is_recoverable():
            raise WorkflowStepNotRecoverableError(s)
        if r.output_object_valid:
            deps = []
        elif isinstance(r.output_step_id, str):
            deps = [r.output_step_id]
        else:
            deps = r.workflows
        for w in deps:
            dependency_map[w].append(s)
        num_in_edges[s] = len(deps)
        dag_visit_queue.extend(deps)

    # Step 2: topological sort to determine the execution order
    # (Kahn's algorithm).
    execution_queue: List[StepID] = []

    start_nodes = deque(k for k, v in num_in_edges.items() if v == 0)
    while start_nodes:
        n = start_nodes.popleft()
        execution_queue.append(n)
        for m in dependency_map[n]:
            num_in_edges[m] -= 1
            assert num_in_edges[m] >= 0, (m, n)
            if num_in_edges[m] == 0:
                start_nodes.append(m)

    # Step 3: recover the workflow in the order of the execution queue.
    with serialization.objectref_cache():
        # "input_map" is a context storing the inputs that have already been
        # loaded. This context is important for deduplicating step inputs.
        input_map: Dict[StepID, Any] = {}

        for _step_id in execution_queue:
            result = inspect_results[_step_id]
            if result.output_object_valid:
                input_map[_step_id] = reader.load_step_output(_step_id)
                continue
            if isinstance(result.output_step_id, str):
                input_map[_step_id] = input_map[result.output_step_id]
                continue

            # Process the wait step as a special case.
            if result.step_options.step_type == StepType.WAIT:
                wait_input_workflows = []
                for w in result.workflows:
                    output = input_map[w]
                    if isinstance(output, Workflow):
                        wait_input_workflows.append(output)
                    else:
                        # Simulate a workflow with a workflow reference so it
                        # could be used directly by 'workflow.wait'.
                        static_ref = WorkflowStaticRef(step_id=w, ref=ray.put(output))
                        wait_input_workflows.append(Workflow.from_ref(static_ref))
                recovery_workflow = ray.workflow.wait(
                    wait_input_workflows,
                    **result.step_options.ray_options.get("wait_options", {}),
                )
            else:
                args, kwargs = reader.load_step_args(
                    _step_id,
                    workflows=[input_map[w] for w in result.workflows],
                    workflow_refs=list(map(WorkflowRef, result.workflow_refs)),
                )
                func: Callable = reader.load_step_func_body(_step_id)
                # TODO(suquark): Use an alternative function when
                #  "workflow.step" is fully deprecated.
                recovery_workflow = ray.workflow.step(func).step(*args, **kwargs)
            # Override step_options.
            recovery_workflow._step_id = _step_id
            recovery_workflow.data.step_options = result.step_options
            input_map[_step_id] = recovery_workflow

    # Step 4: return the output of the requested step.
    return input_map[step_id]
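
# A self-contained sketch of Step 2 above (Kahn's algorithm) on a toy DAG, to
# make the in-degree bookkeeping concrete. Edges point from a dependency to
# its dependents, matching the direction of `dependency_map` in the function
# above; the toy names (`dependents`, `order`) are illustrative only.
from collections import deque
from typing import Dict, List

dependents: Dict[str, List[str]] = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
num_in_edges = {"a": 0, "b": 1, "c": 1, "d": 2}

order: List[str] = []
start_nodes = deque(k for k, v in num_in_edges.items() if v == 0)
while start_nodes:
    n = start_nodes.popleft()
    order.append(n)
    for m in dependents[n]:
        num_in_edges[m] -= 1
        if num_in_edges[m] == 0:
            start_nodes.append(m)

# Every node appears after all of its dependencies.
assert order[0] == "a" and order[-1] == "d"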
def _construct_resume_workflow_from_step(
    reader: workflow_storage.WorkflowStorage,
    step_id: StepID,
    input_map: Dict[StepID, Any],
) -> Union[Workflow, StepID]:
    """Try to construct a workflow (step) that recovers the workflow step.

    If the workflow step already has an output checkpointing file, we return
    the workflow step id instead.

    Args:
        reader: The storage reader for inspecting the step.
        step_id: The ID of the step we want to recover.
        input_map: A context storing the inputs that have already been
            loaded. This context is important for deduplicating step inputs.

    Returns:
        A workflow that recovers the step, or the ID of a step that contains
            the output checkpoint file.
    """
    result: workflow_storage.StepInspectResult = reader.inspect_step(step_id)
    if result.output_object_valid:
        # We already have the output.
        return step_id
    if isinstance(result.output_step_id, str):
        return _construct_resume_workflow_from_step(
            reader, result.output_step_id, input_map
        )
    # The output does not exist or is not valid. Try to reconstruct it.
    if not result.is_recoverable():
        raise WorkflowStepNotRecoverableError(step_id)

    step_options = result.step_options
    # Process the wait step as a special case.
    if step_options.step_type == StepType.WAIT:
        return _reconstruct_wait_step(reader, step_id, result, input_map)

    with serialization.objectref_cache():
        input_workflows = []
        for _step_id in result.workflows:
            # Check whether the step has been loaded or not to avoid
            # duplication.
            if _step_id in input_map:
                r = input_map[_step_id]
            else:
                r = _construct_resume_workflow_from_step(reader, _step_id, input_map)
                input_map[_step_id] = r
            if isinstance(r, Workflow):
                input_workflows.append(r)
            else:
                assert isinstance(r, StepID)
                # TODO(Alex): We should consider caching these outputs too.
                input_workflows.append(reader.load_step_output(r))
        workflow_refs = list(map(WorkflowRef, result.workflow_refs))

        args, kwargs = reader.load_step_args(step_id, input_workflows, workflow_refs)
        # Note: we must unpack args and kwargs, so the refs in the args/kwargs
        # can get resolved consistently, like in Ray.
        recovery_workflow: Workflow = _recover_workflow_step.step(
            input_workflows,
            workflow_refs,
            *args,
            **kwargs,
        )
        recovery_workflow._step_id = step_id
        # Override step_options.
        recovery_workflow.data.step_options = step_options
        return recovery_workflow
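
# A minimal sketch of why the unpacking note above matters, assuming a local
# Ray installation. Ray resolves ObjectRefs passed as *top-level* task
# arguments but leaves refs *nested* inside containers unresolved, so
# forwarding `*args, **kwargs` (rather than a packed tuple) lets Ray resolve
# them consistently. The `show` task is a hypothetical probe, not workflow API.
import ray

ray.init()


@ray.remote
def show(top_level, packed):
    # `top_level` arrives resolved; the ref inside `packed` does not.
    return type(top_level).__name__, type(packed[0]).__name__


ref = ray.put(42)
top, nested = ray.get(show.remote(ref, [ref]))
assert top == "int"  # Resolved by Ray before the task runs.
assert nested == "ObjectRef"  # Still a ref inside the list.

ray.shutdown()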