# Imports assumed by these excerpts (ray.workflow internals; exact module
# paths may vary across Ray versions):
import asyncio
import json
from typing import Any

import ray
from ray.workflow import serialization, serialization_context, workflow_storage
from ray.workflow.common import StepID, TaskID, WorkflowData
from ray.workflow.workflow_state import WorkflowExecutionState


async def _write_step_inputs(wf_storage: workflow_storage.WorkflowStorage,
                             step_id: StepID, inputs: WorkflowData) -> None:
    """Save workflow inputs."""
    metadata = inputs.to_metadata()
    with serialization_context.workflow_args_keeping_context():
        # TODO(suquark): in the future we should write to storage directly
        # with plasma store object in memory.
        args_obj = ray.get(inputs.inputs.args)
    workflow_id = wf_storage._workflow_id
    storage = wf_storage._storage
    save_tasks = [
        # TODO (Alex): Handle the json case better?
        wf_storage._put(
            wf_storage._key_step_input_metadata(step_id), metadata, True),
        wf_storage._put(
            wf_storage._key_step_user_metadata(step_id),
            inputs.user_metadata, True),
        serialization.dump_to_storage(
            wf_storage._key_step_function_body(step_id), inputs.func_body,
            workflow_id, storage),
        serialization.dump_to_storage(
            wf_storage._key_step_args(step_id), args_obj, workflow_id,
            storage),
    ]
    await asyncio.gather(*save_tasks)
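

# A minimal usage sketch (not part of the original source): since
# `_write_step_inputs` is a coroutine, callers must drive it on an event
# loop. `wf_storage`, `step_id`, and `workflow_data` are hypothetical
# placeholders for objects constructed elsewhere in the workflow layer.
def _example_write_inputs(wf_storage: workflow_storage.WorkflowStorage,
                          step_id: StepID,
                          workflow_data: WorkflowData) -> None:
    # asyncio.run spins up an event loop and blocks until all four storage
    # writes gathered in `save_tasks` have completed.
    asyncio.run(_write_step_inputs(wf_storage, step_id, workflow_data))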


def get_value(self, index: int, is_json: bool) -> Any:
    path = self._log_dir / f"{index}.value"
    if is_json:
        with open(path) as f:
            return json.load(f)
    else:
        with open(path, "rb") as f:
            with serialization_context.workflow_args_keeping_context():
                return ray.cloudpickle.load(f)
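

# A hedged sketch of the write side that `get_value` implies (hypothetical;
# the original writer is not shown in these excerpts): each value lands in
# the log directory as `{index}.value`, either as JSON text or as a
# cloudpickle binary blob.
def put_value(self, index: int, value: Any, is_json: bool) -> None:
    path = self._log_dir / f"{index}.value"
    if is_json:
        with open(path, "w") as f:
            json.dump(value, f)
    else:
        # The real writer presumably serializes under a matching
        # workflow-args serialization context; plain cloudpickle is used
        # here only to keep the sketch self-contained.
        with open(path, "wb") as f:
            ray.cloudpickle.dump(value, f)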


def load_step_args(self, task_id: TaskID) -> ray.ObjectRef:
    """Load the input arguments of the workflow step.

    This must be done under a serialization context, otherwise the
    arguments would not be reconstructed successfully.

    Args:
        task_id: ID of the workflow step.

    Returns:
        An object ref of the input args.
    """
    with serialization_context.workflow_args_keeping_context():
        x = self._get(self._key_step_args(task_id))
        return ray.put(x)
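

# A minimal usage sketch (assumption, not from the original source):
# `storage` and `task_id` are placeholders. The returned object ref
# resolves to the stored args object; nested workflow references inside it
# stay unresolved because loading happens under the "keeping" context.
def _example_load_args(storage: workflow_storage.WorkflowStorage,
                       task_id: TaskID) -> Any:
    args_ref = storage.load_step_args(task_id)
    # Dereference to get the raw argument object back from the object store.
    return ray.get(args_ref)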


def save_workflow_execution_state(self, creator_task_id: TaskID,
                                  state: WorkflowExecutionState) -> None:
    """Save a workflow execution state.

    Typically, the state is translated from a Ray DAG.

    Args:
        creator_task_id: The ID of the task that creates the state.
        state: The state converted from the DAG.
    """
    assert creator_task_id != state.output_task_id
    for task_id, task in state.tasks.items():
        # TODO (Alex): Handle the json case better?
        metadata = {
            **task.to_dict(),
            "workflow_refs": state.upstream_dependencies[task_id],
        }
        self._put(self._key_step_input_metadata(task_id), metadata, True)
        # TODO(suquark): The task user metadata is duplicated here.
        self._put(
            self._key_step_user_metadata(task_id), task.user_metadata, True)
        workflow_id = self._workflow_id
        serialization.dump_to_storage(
            self._key_step_function_body(task_id), task.func_body,
            workflow_id, self)
        with serialization_context.workflow_args_keeping_context():
            # TODO(suquark): in the future we should write to storage
            # directly with plasma store object in memory.
            args_obj = ray.get(state.task_input_args[task_id])
        serialization.dump_to_storage(
            self._key_step_args(task_id), args_obj, workflow_id, self)
    # Finally, point to the output ID of the DAG. The DAG is a continuation
    # of the creator task.
    self._put(
        self._key_step_output_metadata(creator_task_id),
        {"output_step_id": state.output_task_id}, True)
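

# A hedged sketch of the read side (hypothetical helper, not part of the
# original class): following the continuation pointer written at the end of
# `save_workflow_execution_state`. This assumes `_get` mirrors `_put`'s
# trailing `is_json` flag, as the JSON/binary split in `load_step_args`
# and the `_put(..., True)` calls above suggest.
def _example_resolve_continuation(storage: workflow_storage.WorkflowStorage,
                                  creator_task_id: TaskID) -> str:
    metadata = storage._get(
        storage._key_step_output_metadata(creator_task_id), True)
    # The creator task's output metadata names the DAG's output task, so a
    # scheduler can resume the continuation from here after a restart.
    return metadata["output_step_id"]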