def step(self, *args, **kwargs):
    """Build the ``Workflow`` object for one call of this actor method.

    The call arguments are flattened against the method signature. For a
    non-readonly method two extra leading positional entries are added: a
    dummy type marker and a reference to the workflow's entrypoint step
    (``None`` when the method is ``__init__``).

    Returns:
        A ``Workflow`` wrapping the prepared ``WorkflowData``.
    """
    call_args = signature.flatten_args(self._signature, args, kwargs)
    current_workflow_id = workflow_context.get_current_workflow_id()

    if not self.readonly:
        previous_state = None
        if self._method_name != "__init__":
            storage = WorkflowStorage(current_workflow_id,
                                      get_global_storage())
            previous_state = WorkflowRef(storage.get_entrypoint_step_id())
        # This is a hack to insert a positional argument.
        injected = [signature.DUMMY_TYPE, previous_state]
        call_args = injected + call_args

    packed_inputs = serialization_context.make_workflow_inputs(call_args)

    if self.readonly:
        method_body = _wrap_readonly_actor_method(
            current_workflow_id, self._original_class, self._method_name)
    else:
        method_body = _wrap_actor_method(self._original_class,
                                         self._method_name)

    return Workflow(
        WorkflowData(
            func_body=method_body,
            inputs=packed_inputs,
            name=self._name,
            step_options=self._options,
            user_metadata=self._user_metadata,
        ))
def test_dynamic_workflow_ref(workflow_start_regular_shared):
    """Run a step, then run a second step under the same workflow id."""
    wf_id = "test_dynamic_workflow_ref"

    # This test also shows different "style" of running workflows.
    initial = incr.step(0)
    assert initial.run(wf_id) == 1

    follow_up = incr.step(WorkflowRef(initial.step_id))
    # Without rerun, it'll just return the previous result
    assert follow_up.run(wf_id) == 1
def _submit_ray_task(self, task_id: TaskID, job_id: str) -> None:
    """Submit a workflow task as a Ray task.

    Args:
        task_id: The ID of the task to submit.
        job_id: Passed through unchanged to the step executor.
    """
    state = self._state

    # Gather the stored arguments and the resolved upstream inputs.
    stored_args = state.task_input_args[task_id]
    upstream_inputs = [
        state.get_input(dep) for dep in state.upstream_dependencies[task_id]
    ]
    baked_inputs = _BakedWorkflowInputs(
        args=stored_args,
        workflow_refs=upstream_inputs,
    )

    task = state.tasks[task_id]
    run_step = get_step_executor(task.options)
    metadata_ref, output_ref = run_step(
        task.func_body,
        state.task_context[task_id],
        job_id,
        task_id,
        baked_inputs,
        task.options,
    )

    # The input workflow is not a reference to an executed workflow.
    # Completion of the metadata future is pushed onto the completion queue.
    metadata_future = asyncio.wrap_future(metadata_ref.future())
    metadata_future.add_done_callback(self._completion_queue.put_nowait)

    pending_output = WorkflowRef(task_id, ref=output_ref)
    state.insert_running_frontier(metadata_future, pending_output)

    submitted_at = time.time()
    state.task_execution_metadata[task_id] = TaskExecutionMetadata(
        submit_time=submitted_at)
def workflow_state_from_storage(
        workflow_id: str,
        task_id: Optional[TaskID]) -> WorkflowExecutionState:
    """Rebuild a workflow execution state from what is stored for a workflow.

    Starting at the output task, the stored steps are visited through a
    queue. Steps with a valid output object are recorded in the checkpoint
    map, steps that returned a continuation are linked to it, and all other
    steps have their dependencies, arguments and function body loaded.

    Args:
        workflow_id: The ID of the workflow.
        task_id: The ID of the output task. If None, it will be the
            entrypoint of the workflow.

    Returns:
        The populated workflow execution state.

    Raises:
        WorkflowStepNotRecoverableError: If a visited step reports that it
            is not recoverable.
    """
    reader = workflow_storage.WorkflowStorage(workflow_id)
    if task_id is None:
        task_id = reader.get_entrypoint_step_id()

    # Construct the workflow execution state.
    state = WorkflowExecutionState(output_task_id=task_id)
    state.output_task_id = task_id

    seen = set()
    pending = deque([task_id])
    with serialization.objectref_cache():
        while pending:
            current: TaskID = pending.popleft()
            if current in seen:
                continue
            seen.add(current)

            step_info = reader.inspect_step(current)
            if not step_info.is_recoverable():
                raise WorkflowStepNotRecoverableError(current)

            if step_info.output_object_valid:
                # The output is stored; record it under the continuation
                # root when one is known for this task.
                target = state.continuation_root.get(current, current)
                state.checkpoint_map[target] = WorkflowRef(current)
            elif isinstance(step_info.output_step_id, str):
                # no input dependencies here because the task has already
                # returned a continuation
                state.upstream_dependencies[current] = []
                state.append_continuation(current, step_info.output_step_id)
                pending.append(step_info.output_step_id)
            else:
                # transfer task info to state
                state.add_dependencies(current, step_info.workflow_refs)
                state.task_input_args[current] = reader.load_step_args(
                    current)
                # TODO(suquark): although not necessary, but for
                # completeness, we may also load name and metadata.
                state.tasks[current] = Task(
                    name="",
                    options=step_info.step_options,
                    user_metadata={},
                    func_body=reader.load_step_func_body(current),
                )
                pending.extend(step_info.workflow_refs)
    return state
def step(method_name, method, *args, **kwargs):
    """Build the ``Workflow`` object for one call of an actor method.

    ``method`` is treated as readonly when it carries a truthy
    ``__virtual_actor_readonly__`` attribute. For other methods two extra
    leading positional entries are added to the flattened arguments: a dummy
    type marker and a reference to the workflow's entrypoint step (``None``
    when the method is ``__init__``).

    Returns:
        A ``Workflow`` wrapping the prepared ``WorkflowData``.
    """
    is_readonly = getattr(method, "__virtual_actor_readonly__", False)
    call_args = self.flatten_args(method_name, args, kwargs)
    current_workflow_id = workflow_context.get_current_workflow_id()

    if not is_readonly:
        previous_state = None
        if method_name != "__init__":
            storage = WorkflowStorage(current_workflow_id,
                                      get_global_storage())
            previous_state = WorkflowRef(storage.get_entrypoint_step_id())
        # This is a hack to insert a positional argument.
        injected = [signature.DUMMY_TYPE, previous_state]
        call_args = injected + call_args

    packed_inputs = serialization_context.make_workflow_inputs(call_args)

    if is_readonly:
        method_body = _wrap_readonly_actor_method(current_workflow_id,
                                                  self.cls, method_name)
        kind = StepType.READONLY_ACTOR_METHOD
    else:
        method_body = _wrap_actor_method(self.cls, method_name)
        kind = StepType.ACTOR_METHOD

    # TODO(suquark): Support actor options.
    return Workflow(
        WorkflowData(
            func_body=method_body,
            step_type=kind,
            inputs=packed_inputs,
            max_retries=1,
            catch_exceptions=False,
            ray_options={},
            name=None,
        ))
def test_dynamic_workflow_ref(workflow_start_regular_shared):
    """Run a workflow, then run a second one under the same workflow id."""

    @ray.remote
    def incr(x):
        return x + 1

    wf_id = "test_dynamic_workflow_ref"

    # This test also shows different "style" of running workflows.
    initial = workflow.create(incr.bind(0))
    assert initial.run(wf_id) == 1

    previous_output = WorkflowRef(initial.step_id)
    follow_up = workflow.create(incr.bind(previous_output))
    # Without rerun, it'll just return the previous result
    assert follow_up.run(wf_id) == 1
def test_dynamic_workflow_ref(workflow_start_regular_shared):
    """Run a workflow, then run a second one under the same workflow id."""

    @ray.remote
    def incr(x):
        return x + 1

    wf_id = "test_dynamic_workflow_ref"

    # This test also shows different "style" of running workflows.
    first_result = workflow.run(incr.bind(0), workflow_id=wf_id)
    assert first_result == 1

    # Without rerun, it'll just return the previous result
    second_result = workflow.run(
        incr.bind(WorkflowRef("incr")), workflow_id=wf_id)
    assert second_result == 1
async def _post_process_ready_task(
        self,
        task_id: TaskID,
        metadata: WorkflowExecutionMetadata,
        output_ref: WorkflowRef,
) -> None:
    """Update the execution state after a task has produced its result.

    When the metadata marks the output as a workflow, the returned
    sub-workflow state is merged into the current state and scheduled as a
    continuation of ``task_id``. Otherwise the output is recorded, waiting
    callbacks receive it, and downstream tasks whose pending inputs are now
    empty are appended to the frontier to run.

    Args:
        task_id: The task whose result is being processed.
        metadata: Execution metadata; only ``is_output_workflow`` is read.
        output_ref: Reference to the task output.
    """
    state = self._state
    state.task_retries.pop(task_id, None)
    done_callbacks = self._task_done_callbacks

    if metadata.is_output_workflow:
        # The task returns a continuation
        sub_workflow_state: WorkflowExecutionState = await output_ref.ref
        # init the context just for "sub_workflow_state"
        sub_workflow_state.init_context(state.task_context[task_id])
        state.merge_state(sub_workflow_state)

        # build up runtime dependency
        continuation_task_id = sub_workflow_state.output_task_id
        state.append_continuation(task_id, continuation_task_id)

        # Migrate callbacks - all continuation callbacks are moved
        # under the root of continuation, so when the continuation
        # completes, all callbacks in the continuation can be triggered.
        if continuation_task_id in done_callbacks:
            # Look up the root's list before popping, as the original
            # single expression did.
            root_id = state.continuation_root[continuation_task_id]
            root_callbacks = done_callbacks[root_id]
            migrated = done_callbacks.pop(continuation_task_id)
            root_callbacks.extend(migrated)

        state.construct_scheduling_plan(sub_workflow_state.output_task_id)
        return

    # The task returns a normal object
    target_task_id = state.continuation_root.get(task_id, task_id)
    state.output_map[target_task_id] = output_ref
    if state.tasks[task_id].options.checkpoint:
        state.checkpoint_map[target_task_id] = WorkflowRef(task_id)
    state.done_tasks.add(target_task_id)

    # TODO(suquark): cleanup callbacks when a result is set?
    if target_task_id in done_callbacks:
        for waiter in done_callbacks[target_task_id]:
            waiter.set_result(output_ref)

    for downstream in state.reference_set[target_task_id]:
        # we ensure that each reference corresponds to a pending input
        still_pending = state.pending_input_set[downstream]
        still_pending.remove(target_task_id)
        if not still_pending:
            state.append_frontier_to_run(downstream)
def _node_visitor(node: Any) -> Any:
    """Convert one DAG node into its workflow counterpart.

    For a ``FunctionNode`` a workflow task is registered in ``state`` (its
    options, serialized arguments, dependencies and user metadata) and a
    ``WorkflowRef`` to the new task is returned. Input nodes are replaced by
    input data and non-DAG objects are returned unchanged.

    The node's bound options are only read: user metadata is fetched with
    ``get`` rather than ``pop``, because the workflow options dict is still
    shared with ``node._bound_options`` (``dict.copy`` is shallow). Popping
    would drop the metadata from the node for any later visit.

    Args:
        node: A DAG node or an arbitrary Python object.

    Returns:
        A ``WorkflowRef`` for function nodes, the input data for input
        nodes, or ``node`` itself when it is not a DAG node.

    Raises:
        ValueError: If the node requests more than one return value or has
            an invalid ``max_retries`` option.
        TypeError: If ``node`` is a DAG node of an unsupported kind.
    """
    if isinstance(node, FunctionNode):
        bound_options = node._bound_options.copy()
        num_returns = bound_options.get("num_returns", 1)
        if num_returns is None:  # ray could use `None` as default value
            num_returns = 1
        if num_returns > 1:
            raise ValueError("Workflow steps can only have one return.")

        # "_metadata" is removed from the shallow copy only, so it does not
        # end up in the ray options; the nested dict itself is still shared
        # with the node and must not be mutated.
        workflow_options = bound_options.pop("_metadata", {}).get(
            WORKFLOW_OPTIONS, {})

        # If checkpoint option is not specified, inherit checkpoint
        # options from context (i.e. checkpoint options of the outer
        # step). If it is still not specified, it's True by default.
        checkpoint = workflow_options.get("checkpoint", None)
        if checkpoint is None:
            checkpoint = context.checkpoint if context is not None else True

        # When it returns a nested workflow, catch_exception
        # should be passed recursively.
        catch_exceptions = workflow_options.get("catch_exceptions", None)
        if catch_exceptions is None:
            # TODO(suquark): should we also handle exceptions from a
            # "leaf node" in the continuation? For example, we have a
            # workflow
            # > @ray.remote
            # > def A(): pass
            # > @ray.remote
            # > def B(x): return x
            # > @ray.remote
            # > def C(x): return workflow.continuation(B.bind(A.bind()))
            # > dag = C.options(**workflow.options(catch_exceptions=True)).bind()
            # Should C catches exceptions of A?
            if node.get_stable_uuid() == dag_node.get_stable_uuid():
                # 'catch_exception' context should be passed down to
                # its direct continuation task.
                # In this case, the direct continuation is the output node.
                catch_exceptions = (context.catch_exceptions
                                    if context is not None else False)
            else:
                catch_exceptions = False

        max_retries = bound_options.get("max_retries", 3)
        if not isinstance(max_retries, int) or max_retries < -1:
            raise ValueError(
                "'max_retries' only accepts 0, -1 or a positive integer.")

        step_options = WorkflowStepRuntimeOptions(
            step_type=StepType.FUNCTION,
            catch_exceptions=catch_exceptions,
            max_retries=max_retries,
            allow_inplace=False,
            checkpoint=checkpoint,
            ray_options=bound_options,
        )

        workflow_refs: List[WorkflowRef] = []
        with serialization_context.workflow_args_serialization_context(
                workflow_refs):
            _func_signature = signature.extract_signature(node._body)
            flattened_args = signature.flatten_args(
                _func_signature, node._bound_args, node._bound_kwargs)
            # NOTE: When calling 'ray.put', we trigger python object
            # serialization. Under our serialization context,
            # Workflows are separated from the arguments,
            # leaving a placeholder object with all other python objects.
            # Then we put the placeholder object to object store,
            # so it won't be mutated later. This guarantees correct
            # semantics. See "tests/test_variable_mutable.py" as
            # an example.
            input_placeholder: ray.ObjectRef = ray.put(flattened_args)

        name = workflow_options.get("name")
        if name is None:
            name = (f"{get_module(node._body)}."
                    f"{slugify(get_qualname(node._body))}")
        task_id = ray.get(mgr.gen_step_id.remote(workflow_id, name))
        state.add_dependencies(task_id, [s.task_id for s in workflow_refs])
        state.task_input_args[task_id] = input_placeholder

        # Read-only lookup: do not pop from the dict shared with the node.
        user_metadata = workflow_options.get("metadata", {})
        validate_user_metadata(user_metadata)
        state.tasks[task_id] = Task(
            name=name,
            options=step_options,
            user_metadata=user_metadata,
            func_body=node._body,
        )
        return WorkflowRef(task_id)

    if isinstance(node, InputAttributeNode):
        return node._execute_impl()  # get data from input node
    if isinstance(node, InputNode):
        return input_context  # replace input node with input data
    if not isinstance(node, DAGNode):
        return node  # return normal objects
    raise TypeError(f"Unsupported DAG node: {node}")
async def _handle_ready_task(self, fut: asyncio.Future, workflow_id: str,
                             wf_store: "WorkflowStorage") -> None:
    """Handle ready task, especially about its exception.

    On success the task result is post-processed. On cancellation the
    workflow is marked canceled and the cancellation is broadcast. On any
    other exception, a task in the call stack with ``catch_exceptions`` set
    receives the exception as its output if the error was raised by the
    task itself; otherwise the workflow is marked failed.

    Args:
        fut: The finished future taken from the running frontier.
        workflow_id: ID of the workflow, used in logs and errors.
        wf_store: Storage used to persist workflow status changes.

    Raises:
        WorkflowCancellationError: If the future was cancelled.
        WorkflowExecutionError: If the task failed and no task catches it.
    """
    state = self._state
    output_ref = state.pop_running_frontier(fut)
    task_id = output_ref.task_id
    try:
        metadata: WorkflowExecutionMetadata = fut.result()
        state.task_execution_metadata[task_id].finish_time = time.time()
        logger.info(f"Task status [{WorkflowStatus.SUCCESSFUL}]\t"
                    f"[{workflow_id}@{task_id}]")
        await self._post_process_ready_task(task_id, metadata, output_ref)
    except asyncio.CancelledError:
        # NOTE: We must update the workflow status before broadcasting
        # the exception. Otherwise, the workflow status would still be
        # 'RUNNING' if check the status immediately after cancellation.
        wf_store.update_workflow_status(WorkflowStatus.CANCELED)
        logger.warning(f"Workflow '{workflow_id}' is cancelled.")
        # broadcasting cancellation to all outputs
        err = WorkflowCancellationError(workflow_id)
        self._broadcast_exception(err)
        raise err from None
    except Exception as exc:
        # Only errors raised by the task itself can be caught by a task.
        is_application_error = isinstance(exc, RayTaskError)
        if is_application_error:
            reason = "an exception raised by the task"
        elif isinstance(exc, RayError):
            reason = "a system error"
        else:
            reason = "an unknown error"
        logger.error(
            f"Task status [{WorkflowStatus.FAILED}] due to {reason}.\t"
            f"[{workflow_id}@{task_id}]")

        # on error, the error is caught by this task;
        # lookup a creator task that catches the exception
        exception_catching_task_id = None
        if is_application_error:
            exception_catching_task_id = next(
                (candidate_id
                 for candidate_id, candidate in self._iter_callstack(task_id)
                 if candidate.options.catch_exceptions),
                None,
            )

        if exception_catching_task_id is None:
            # NOTE: We must update the workflow status before broadcasting
            # the exception. Otherwise, the workflow status would still be
            # 'RUNNING' if check the status immediately after the exception.
            wf_store.update_workflow_status(WorkflowStatus.FAILED)
            logger.error(f"Workflow '{workflow_id}' failed due to {exc}")
            err = WorkflowExecutionError(workflow_id)
            err.__cause__ = exc  # chain exceptions
            self._broadcast_exception(err)
            raise err

        logger.info(
            f"Exception raised by '{workflow_id}@{task_id}' is caught by "
            f"'{workflow_id}@{exception_catching_task_id}'")
        # assign output to exception catching task;
        # compose output with caught exception
        caught_output = WorkflowRef(task_id, ray.put((None, exc)))
        await self._post_process_ready_task(
            exception_catching_task_id,
            metadata=WorkflowExecutionMetadata(),
            output_ref=caught_output,
        )