def commit_step(
    store: workflow_storage.WorkflowStorage,
    step_id: "StepID",
    ret: Union["Workflow", Any],
    *,
    exception: Optional[Exception],
):
    """Checkpoint the step output.

    Args:
        store: The storage the current workflow is using.
        step_id: The ID of the step.
        ret: The returned object of the workflow step.
        exception: The exception caught by the step.
    """
    from ray.workflow.common import Workflow

    if isinstance(ret, Workflow):
        assert not ret.executed
        tasks = []
        for w in ret._iter_workflows_in_dag():
            # If this is a reference to a workflow, do not checkpoint
            # its input (again).
            if w.ref is None:
                tasks.append(_write_step_inputs(store, w.step_id, w.data))
        asyncio_run(asyncio.gather(*tasks))

    context = workflow_context.get_workflow_step_context()
    store.save_step_output(
        step_id, ret, exception=exception, outer_most_step_id=context.outer_most_step_id
    )
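# A minimal usage sketch, assuming `store` is a WorkflowStorage instance and
# `my_step_id` a valid step ID (both hypothetical names): a successful step
# commits its return value, while a failed step commits the caught exception.
#
#   commit_step(store, my_step_id, ret=42, exception=None)    # checkpoint output
#   commit_step(store, my_step_id, ret=None, exception=err)   # checkpoint failure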
def save_actor_class_body(self, cls: type) -> None:
    """Save the class body of the virtual actor.

    Args:
        cls: The class body used by the virtual actor.

    Raises:
        DataSaveError: if we fail to save the class body.
    """
    asyncio_run(self._put(self._key_class_body(), cls))
def save_workflow_postrun_metadata(self, metadata: Dict[str, Any]):
    """Save post-run metadata of the current workflow.

    Args:
        metadata: post-run metadata of the current workflow.

    Raises:
        DataSaveError: if we fail to save the post-run metadata.
    """
    asyncio_run(self._put(self._key_workflow_postrun_metadata(), metadata, True))
def resume(num_records_replayed):
    key = debug_store.wrapped_storage.make_key("complex_workflow")
    asyncio_run(debug_store.wrapped_storage.delete_prefix(key))

    async def replay():
        # We need to replay one by one to avoid conflicts.
        for i in range(num_records_replayed):
            await debug_store.replay(i)

    asyncio_run(replay())
    return ray.get(workflow.resume(workflow_id="complex_workflow"))
def save_step_postrun_metadata(self, step_id: StepID, metadata: Dict[str, Any]):
    """Save post-run metadata of the current step.

    Args:
        step_id: ID of the workflow step.
        metadata: post-run metadata of the current step.

    Raises:
        DataSaveError: if we fail to save the post-run metadata.
    """
    asyncio_run(self._put(self._key_step_postrun_metadata(step_id), metadata, True))
def save_workflow_meta(self, metadata: WorkflowMetaData) -> None:
    """Save the metadata of the current workflow.

    Args:
        metadata: WorkflowMetaData of the current workflow.

    Raises:
        DataSaveError: if we fail to save the metadata.
    """
    metadata = {
        "status": metadata.status.value,
    }
    asyncio_run(self._put(self._key_workflow_metadata(), metadata, True))
def load_step_output(self, step_id: StepID) -> Any:
    """Load the output of the workflow step from checkpoint.

    Args:
        step_id: ID of the workflow step.

    Returns:
        Output of the workflow step.
    """
    tasks = [
        self._get(self._key_step_output(step_id), no_exception=True),
        self._get(self._key_step_exception(step_id), no_exception=True),
    ]
    ((output_ret, output_err), (exception_ret, exception_err)) = asyncio_run(
        asyncio.gather(*tasks)
    )

    # When we have an output, always return the output first.
    if output_err is None:
        return output_ret

    # When we don't have an output, check for a checkpointed exception.
    if exception_err is None:
        raise exception_ret

    # Neither exists, so there is no such step.
    raise output_err
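# A sketch of the three possible outcomes above, with `store` as a
# hypothetical WorkflowStorage instance (and assuming `_get` raises
# KeyNotFoundError for missing keys, as in `gen_step_id` below):
#
#   try:
#       value = store.load_step_output(step_id)
#   except KeyNotFoundError:
#       ...  # no output and no exception were checkpointed for this step
#   except Exception:
#       ...  # the step failed; its checkpointed exception is re-raised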
def load_step_args(
    self, step_id: StepID, workflows: List[Any], workflow_refs: List[WorkflowRef]
) -> Tuple[List, Dict[str, Any]]:
    """Load the input arguments of the workflow step.

    This must be done under a serialization context; otherwise the
    arguments would not be reconstructed successfully.

    Args:
        step_id: ID of the workflow step.
        workflows: The workflows in the original arguments,
            replaced by the actual workflow outputs.
        workflow_refs: The workflow refs in the original arguments.

    Returns:
        Args and kwargs.
    """
    with serialization_context.workflow_args_resolving_context(
        workflows, workflow_refs
    ):
        flattened_args = asyncio_run(self._get(self._key_step_args(step_id)))
        # Dereference arguments like Ray remote functions do.
        flattened_args = [
            ray.get(a) if isinstance(a, ray.ObjectRef) else a for a in flattened_args
        ]
        return signature.recover_args(flattened_args)
def load_actor_class_body(self) -> type:
    """Load the class body of the virtual actor.

    Returns:
        The class body of the virtual actor.

    Raises:
        DataLoadError: if we fail to load the class body.
    """
    return asyncio_run(self._get(self._key_class_body()))
def save_object_ref(self, obj_ref: ray.ObjectRef) -> None:
    """Save the object ref.

    Args:
        obj_ref: The object reference.

    Returns:
        None
    """
    return asyncio_run(self._save_object_ref(obj_ref))
def load_step_func_body(self, step_id: StepID) -> Callable:
    """Load the function body of the workflow step.

    Args:
        step_id: ID of the workflow step.

    Returns:
        A callable function.
    """
    return asyncio_run(self._get(self._key_step_function_body(step_id)))
def advance_progress(self, finished_step_id: "StepID") -> None:
    """Save the latest progress of a workflow. This is used by a
    virtual actor.

    Args:
        finished_step_id: The step that contains the latest output.

    Raises:
        DataSaveError: if we fail to save the progress.
    """
    asyncio_run(
        self._put(
            self._key_workflow_progress(),
            {
                "step_id": finished_step_id,
            },
            True,
        )
    )
def get_latest_progress(self) -> "StepID":
    """Load the latest progress of a workflow. This is used by a
    virtual actor.

    Raises:
        DataLoadError: if we fail to load the progress.

    Returns:
        The step that contains the latest output.
    """
    return asyncio_run(self._get(self._key_workflow_progress(), True))["step_id"]
def delete_workflow(self):
    prefix = self._storage.make_key(self._workflow_id)

    scan = []
    scan_future = self._storage.scan_prefix(prefix)
    delete_future = self._storage.delete_prefix(prefix)
    try:
        # TODO (Alex): There's a race condition here if someone tries to
        # start the workflow between these two ops.
        scan = asyncio_run(scan_future)
        asyncio_run(delete_future)
    except FileNotFoundError:
        # TODO (Alex): Different file systems seem to have different
        # behavior when deleting a prefix that doesn't exist, so we may
        # need to catch a broader class of exceptions.
        pass

    if not scan:
        raise WorkflowNotFoundError(self._workflow_id)
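# Hedged usage sketch, assuming a `WorkflowStorage(workflow_id, storage)`
# constructor (hypothetical here): deleting a workflow that was never stored
# surfaces as WorkflowNotFoundError.
#
#   wf_storage = WorkflowStorage("my_workflow", storage)
#   try:
#       wf_storage.delete_workflow()
#   except WorkflowNotFoundError:
#       ...  # nothing was stored under "my_workflow"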
def gen_step_id(self, step_name: str) -> int:
    async def _gen_step_id():
        key = self._key_num_steps_with_name(step_name)
        try:
            val = await self._get(key, True)
            await self._put(key, val + 1, True)
            return val + 1
        except KeyNotFoundError:
            await self._put(key, 0, True)
            return 0

    return asyncio_run(_gen_step_id())
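# The counter is kept per step name, so repeated calls yield monotonically
# increasing IDs (sketch; `store` is a hypothetical WorkflowStorage instance):
#
#   store.gen_step_id("fetch")    # -> 0
#   store.gen_step_id("fetch")    # -> 1
#   store.gen_step_id("process")  # -> 0  (independent counter)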
def inspect_step(self, step_id: StepID) -> StepInspectResult:
    """Get the status of a workflow step.

    The status indicates whether the workflow step can be recovered, etc.

    Args:
        step_id: The ID of a workflow step.

    Returns:
        The status of the step.
    """
    return asyncio_run(self._inspect_step(step_id))
def load_workflow_meta(self) -> Optional[WorkflowMetaData]:
    """Load the metadata of the current workflow.

    Returns:
        The metadata of the current workflow. If it doesn't exist,
        return None.
    """
    try:
        metadata = asyncio_run(self._get(self._key_workflow_metadata(), True))
        return WorkflowMetaData(status=WorkflowStatus(metadata["status"]))
    except KeyNotFoundError:
        return None
def get_entrypoint_step_id(self) -> StepID:
    """Load the entrypoint step ID of the workflow.

    Returns:
        The ID of the entrypoint step.
    """
    # An empty StepID represents the workflow driver.
    try:
        return asyncio_run(self._locate_output_step_id(""))
    except Exception as e:
        raise ValueError(
            "Failed to get entrypoint step ID from workflow "
            f"[id={self._workflow_id}]"
        ) from e
def _put_helper(
    identifier: str, obj: Any, workflow_id: str, storage: storage.Storage
) -> None:
    # TODO (Alex): This check isn't sufficient; it only works for directly
    # nested object refs.
    if isinstance(obj, ray.ObjectRef):
        raise NotImplementedError(
            "Workflow does not support checkpointing nested object references yet."
        )
    paths = obj_id_to_paths(workflow_id, identifier)
    promise = dump_to_storage(paths, obj, workflow_id, storage, update_existing=False)
    return common.asyncio_run(promise)
def load_object_ref(self, object_id: str) -> ray.ObjectRef:
    """Load the input object ref.

    Args:
        object_id: The hex ObjectID.

    Returns:
        The object ref.
    """

    async def _load_obj_ref() -> ray.ObjectRef:
        data = await self._get(self._key_obj_id(object_id))
        ref = _put_obj_ref.remote((data,))
        return ref

    return asyncio_run(_load_obj_ref())
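# Round-trip sketch with save_object_ref above (hypothetical names; `ref` is
# a ray.ObjectRef, and its hex ID is assumed to be the storage key):
#
#   store.save_object_ref(ref)                   # persist the referenced data
#   restored = store.load_object_ref(ref.hex())  # a new ref to the same data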
def load_step_metadata(self, step_id: StepID) -> Dict[str, Any]:
    """Load the metadata of the given step.

    Returns:
        The metadata of the given step.
    """

    async def _load_step_metadata():
        if not await self._scan([self._workflow_id, "steps", step_id]):
            if not await self._scan([self._workflow_id]):
                raise ValueError("No such workflow_id {}".format(self._workflow_id))
            else:
                raise ValueError(
                    "No such step_id {} in workflow {}".format(
                        step_id, self._workflow_id
                    )
                )

        tasks = [
            self._get(self._key_step_input_metadata(step_id), True, True),
            self._get(self._key_step_prerun_metadata(step_id), True, True),
            self._get(self._key_step_postrun_metadata(step_id), True, True),
        ]
        (
            (input_metadata, _),
            (prerun_metadata, _),
            (postrun_metadata, _),
        ) = await asyncio.gather(*tasks)

        input_metadata = input_metadata or {}
        prerun_metadata = prerun_metadata or {}
        postrun_metadata = postrun_metadata or {}

        metadata = input_metadata
        metadata["stats"] = {}
        metadata["stats"].update(prerun_metadata)
        metadata["stats"].update(postrun_metadata)

        return metadata

    return asyncio_run(_load_step_metadata())
def load_workflow_metadata(self) -> Dict[str, Any]:
    """Load the metadata of the current workflow.

    Returns:
        The metadata of the current workflow.
    """

    async def _load_workflow_metadata():
        if not await self._scan([self._workflow_id]):
            raise ValueError("No such workflow_id {}".format(self._workflow_id))

        tasks = [
            self._get(self._key_workflow_metadata(), True, True),
            self._get(self._key_workflow_user_metadata(), True, True),
            self._get(self._key_workflow_prerun_metadata(), True, True),
            self._get(self._key_workflow_postrun_metadata(), True, True),
        ]
        (
            (status_metadata, _),
            (user_metadata, _),
            (prerun_metadata, _),
            (postrun_metadata, _),
        ) = await asyncio.gather(*tasks)

        status_metadata = status_metadata or {}
        user_metadata = user_metadata or {}
        prerun_metadata = prerun_metadata or {}
        postrun_metadata = postrun_metadata or {}

        metadata = status_metadata
        metadata["user_metadata"] = user_metadata
        metadata["stats"] = {}
        metadata["stats"].update(prerun_metadata)
        metadata["stats"].update(postrun_metadata)

        return metadata

    return asyncio_run(_load_workflow_metadata())
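# Illustrative shape of the merged dict returned above (field names inside
# each sub-dict depend on what was actually saved; "status" comes from
# save_workflow_meta earlier):
#
#   {
#       "status": "RUNNING",
#       "user_metadata": {...},
#       "stats": {...},  # pre-run metadata merged with post-run metadata
#   }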
def resume_all(with_failed: bool) -> List[Tuple[str, ray.ObjectRef]]:
    filter_set = {WorkflowStatus.RESUMABLE}
    if with_failed:
        filter_set.add(WorkflowStatus.FAILED)
    all_failed = list_all(filter_set)

    try:
        workflow_manager = get_management_actor()
    except Exception as e:
        raise RuntimeError("Failed to get management actor") from e

    async def _resume_one(wid: str) -> Tuple[str, Optional[ray.ObjectRef]]:
        try:
            job_id = ray.get_runtime_context().job_id.hex()
            result: "WorkflowExecutionResult" = (
                await workflow_manager.run_or_resume.remote(job_id, wid)
            )
            obj = flatten_workflow_output(wid, result.persisted_output)
            return wid, obj
        except Exception:
            logger.error(f"Failed to resume workflow {wid}")
            return (wid, None)

    ret = asyncio_run(asyncio.gather(*[_resume_one(wid) for (wid, _) in all_failed]))
    return [(wid, obj) for (wid, obj) in ret if obj is not None]
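# Hedged usage sketch: resume every resumable workflow, including failed ones,
# and block on the resulting output refs (assumes Ray and workflow storage are
# already initialized):
#
#   for wid, output_ref in resume_all(with_failed=True):
#       print(wid, ray.get(output_ref))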
def save_step_output(
    self,
    step_id: StepID,
    ret: Union[Workflow, Any],
    *,
    exception: Optional[Exception],
    outer_most_step_id: StepID,
) -> None:
    """When a workflow step returns,
    1. If the returned object is a workflow, this means we are a nested
       workflow. We save the output metadata that points to the workflow.
    2. Otherwise, checkpoint the output.

    Args:
        step_id: The ID of the workflow step. If it is an empty string,
            it means we are in the workflow job driver process.
        ret: The returned object from a workflow step.
        exception: The exception caught by the step, if any.
        outer_most_step_id: See WorkflowStepContext.
    """
    tasks = []
    dynamic_output_id = None
    if isinstance(ret, Workflow):
        # This workflow step returns a nested workflow.
        assert step_id != ret.step_id
        assert exception is None
        tasks.append(
            self._put(
                self._key_step_output_metadata(step_id),
                {"output_step_id": ret.step_id},
                True,
            )
        )
        dynamic_output_id = ret.step_id
    else:
        if exception is None:
            # This workflow step returns an object.
            ret = ray.get(ret) if isinstance(ret, ray.ObjectRef) else ret
            promise = serialization.dump_to_storage(
                self._key_step_output(step_id),
                ret,
                self._workflow_id,
                self._storage,
            )
            tasks.append(promise)
            dynamic_output_id = step_id
            # TODO (yic): Delete the exception file.
        else:
            assert ret is None
            promise = serialization.dump_to_storage(
                self._key_step_exception(step_id),
                exception,
                self._workflow_id,
                self._storage,
            )
            tasks.append(promise)

    # Finish checkpointing.
    asyncio_run(asyncio.gather(*tasks))
    # NOTE: if we updated the dynamic output before finishing
    # checkpointing, then during recovery the dynamic output could
    # point to a checkpoint that does not exist.
    if dynamic_output_id is not None:
        asyncio_run(self._update_dynamic_output(outer_most_step_id, dynamic_output_id))
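# Hedged sketch of the nested-workflow branch above (illustrative names): if a
# step returns another workflow, only a pointer to that workflow's step is
# checkpointed, and the dynamic output of `outer_most_step_id` is redirected
# to it once the checkpoint completes.
#
#   store.save_step_output(
#       "outer_step", inner_workflow, exception=None, outer_most_step_id=""
#   )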
def _load_ref_helper(key: str, storage: storage.Storage):
    # TODO (Alex): We should stream the data directly into `cloudpickle.load`.
    serialized = common.asyncio_run(storage.get(key))
    return cloudpickle.loads(serialized)
def message_committed(event_listener_type: EventListenerType, event: Event) -> Event:
    event_listener = event_listener_type()
    asyncio_run(event_listener.event_checkpointed(event))
    return event
def get_message(event_listener_type: EventListenerType, *args, **kwargs) -> Event:
    event_listener = event_listener_type()
    return asyncio_run(event_listener.poll_for_event(*args, **kwargs))
def list_workflow(self) -> List[Tuple[str, WorkflowStatus]]:
    return asyncio_run(self._list_workflow())