Example 1
    def step(self, *args, **kwargs):
        flattened_args = signature.flatten_args(self._signature, args, kwargs)
        actor_id = workflow_context.get_current_workflow_id()
        if not self.readonly:
            if self._method_name == "__init__":
                state_ref = None
            else:
                ws = WorkflowStorage(actor_id, get_global_storage())
                state_ref = WorkflowRef(ws.get_entrypoint_step_id())
            # This is a hack to insert a positional argument.
            flattened_args = [signature.DUMMY_TYPE, state_ref] + flattened_args
        workflow_inputs = serialization_context.make_workflow_inputs(
            flattened_args)

        if self.readonly:
            _actor_method = _wrap_readonly_actor_method(
                actor_id, self._original_class, self._method_name)
        else:
            _actor_method = _wrap_actor_method(self._original_class,
                                               self._method_name)
        workflow_data = WorkflowData(
            func_body=_actor_method,
            inputs=workflow_inputs,
            name=self._name,
            step_options=self._options,
            user_metadata=self._user_metadata,
        )
        wf = Workflow(workflow_data)
        return wf
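
For context: the step builder above belongs to Ray's experimental virtual-actor layer, and non-readonly methods receive the actor's persisted state as a hidden first argument (a WorkflowRef to the entrypoint step). A minimal sketch of the user-facing side this supports, assuming the `workflow.virtual_actor` API these internals implement:

from ray import workflow

@workflow.virtual_actor
class Counter:
    def __init__(self):
        self.count = 0

    def incr(self):
        # Non-readonly method: the builder above threads the persisted state
        # in as a hidden positional argument before the step runs.
        self.count += 1
        return self.count

counter = Counter.get_or_create("my_counter")  # hypothetical actor id
assert counter.incr.run() == 1  # each call becomes a workflow step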
Example 2
def test_dynamic_workflow_ref(workflow_start_regular_shared):
    # This test also shows a different "style" of running workflows.
    first_step = incr.step(0)
    assert first_step.run("test_dynamic_workflow_ref") == 1
    second_step = incr.step(WorkflowRef(first_step.step_id))
    # Without a rerun, it just returns the previous result
    assert second_step.run("test_dynamic_workflow_ref") == 1
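
Example 2 assumes a step function `incr` defined elsewhere; with the older `@workflow.step` API that the `incr.step(0)` calling style belongs to, it would look roughly like:

from ray import workflow

@workflow.step
def incr(x):
    return x + 1

Because both runs share the workflow id "test_dynamic_workflow_ref", the second `run()` returns the checkpointed result of the first run instead of executing again.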
Example 3
    def _submit_ray_task(self, task_id: TaskID, job_id: str) -> None:
        """Submit a workflow task as a Ray task."""
        state = self._state
        baked_inputs = _BakedWorkflowInputs(
            args=state.task_input_args[task_id],
            workflow_refs=[
                state.get_input(d)
                for d in state.upstream_dependencies[task_id]
            ],
        )
        task = state.tasks[task_id]
        executor = get_step_executor(task.options)
        metadata_ref, output_ref = executor(
            task.func_body,
            state.task_context[task_id],
            job_id,
            task_id,
            baked_inputs,
            task.options,
        )
        # Wrap the Ray future as an asyncio future, so the event loop is
        # notified once the task completes.
        future = asyncio.wrap_future(metadata_ref.future())
        future.add_done_callback(self._completion_queue.put_nowait)

        state.insert_running_frontier(future,
                                      WorkflowRef(task_id, ref=output_ref))
        state.task_execution_metadata[task_id] = TaskExecutionMetadata(
            submit_time=time.time())
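
The snippets construct `WorkflowRef` both with a bare task id and with `ref=output_ref`. A sketch of the shape they assume (field names inferred from usage, not the authoritative definition):

from dataclasses import dataclass
from typing import Optional

import ray

@dataclass
class WorkflowRef:
    # ID of the task this reference points to.
    task_id: str
    # Output of that task, once (or if) it is available.
    ref: Optional[ray.ObjectRef] = None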
Example 4
def workflow_state_from_storage(
        workflow_id: str, task_id: Optional[TaskID]) -> WorkflowExecutionState:
    """Try to construct a workflow (step) that recovers the workflow step.
    If the workflow step already has an output checkpointing file, we return
    the workflow step id instead.

    Args:
        workflow_id: The ID of the workflow.
        task_id: The ID of the output task. If None, it will be the entrypoint of
            the workflow.

    Returns:
        The workflow execution state that recovers the step. Steps with
            checkpointed outputs are resolved via their checkpoints.
    """
    reader = workflow_storage.WorkflowStorage(workflow_id)
    if task_id is None:
        task_id = reader.get_entrypoint_step_id()

    # Construct the workflow execution state.
    state = WorkflowExecutionState(output_task_id=task_id)

    visited_tasks = set()
    dag_visit_queue = deque([task_id])
    with serialization.objectref_cache():
        while dag_visit_queue:
            task_id: TaskID = dag_visit_queue.popleft()
            if task_id in visited_tasks:
                continue
            visited_tasks.add(task_id)
            r = reader.inspect_step(task_id)
            if not r.is_recoverable():
                raise WorkflowStepNotRecoverableError(task_id)
            if r.output_object_valid:
                target = state.continuation_root.get(task_id, task_id)
                state.checkpoint_map[target] = WorkflowRef(task_id)
                continue
            if isinstance(r.output_step_id, str):
                # no input dependencies here because the task has already
                # returned a continuation
                state.upstream_dependencies[task_id] = []
                state.append_continuation(task_id, r.output_step_id)
                dag_visit_queue.append(r.output_step_id)
                continue
            # transfer task info to state
            state.add_dependencies(task_id, r.workflow_refs)
            state.task_input_args[task_id] = reader.load_step_args(task_id)
            # TODO(suquark): Though not strictly necessary, for completeness
            #  we may also load the task name and metadata.
            state.tasks[task_id] = Task(
                name="",
                options=r.step_options,
                user_metadata={},
                func_body=reader.load_step_func_body(task_id),
            )

            dag_visit_queue.extend(r.workflow_refs)

    return state
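
A hypothetical invocation, assuming a workflow "my_workflow" was previously checkpointed to storage:

# Rebuild the execution state, starting from the entrypoint task.
state = workflow_state_from_storage("my_workflow", task_id=None)
# Tasks whose outputs were checkpointed are resolved through
# state.checkpoint_map instead of being re-executed.
print(state.output_task_id, state.checkpoint_map)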
Example 5
            def step(method_name, method, *args, **kwargs):
                readonly = getattr(method, "__virtual_actor_readonly__", False)
                flattened_args = self.flatten_args(method_name, args, kwargs)
                actor_id = workflow_context.get_current_workflow_id()
                if not readonly:
                    if method_name == "__init__":
                        state_ref = None
                    else:
                        ws = WorkflowStorage(actor_id, get_global_storage())
                        state_ref = WorkflowRef(ws.get_entrypoint_step_id())
                    # This is a hack to insert a positional argument.
                    flattened_args = (
                        [signature.DUMMY_TYPE, state_ref] + flattened_args)
                workflow_inputs = serialization_context.make_workflow_inputs(
                    flattened_args)

                if readonly:
                    _actor_method = _wrap_readonly_actor_method(
                        actor_id, self.cls, method_name)
                    step_type = StepType.READONLY_ACTOR_METHOD
                else:
                    _actor_method = _wrap_actor_method(self.cls, method_name)
                    step_type = StepType.ACTOR_METHOD
                # TODO(suquark): Support actor options.
                workflow_data = WorkflowData(
                    func_body=_actor_method,
                    step_type=step_type,
                    inputs=workflow_inputs,
                    max_retries=1,
                    catch_exceptions=False,
                    ray_options={},
                    name=None,
                )
                wf = Workflow(workflow_data)
                return wf
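
Example 5 is an earlier revision of the virtual-actor step builder from Example 1: it still tags the step with an explicit `step_type` and passes `max_retries`, `catch_exceptions` and `ray_options` as flat `WorkflowData` fields, which Example 1 has since folded into `step_options`.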
Example 6
def test_dynamic_workflow_ref(workflow_start_regular_shared):
    @ray.remote
    def incr(x):
        return x + 1

    # This test also shows a different "style" of running workflows.
    first_step = workflow.create(incr.bind(0))
    assert first_step.run("test_dynamic_workflow_ref") == 1
    second_step = workflow.create(incr.bind(WorkflowRef(first_step.step_id)))
    # Without a rerun, it just returns the previous result
    assert second_step.run("test_dynamic_workflow_ref") == 1
Example 7
def test_dynamic_workflow_ref(workflow_start_regular_shared):
    @ray.remote
    def incr(x):
        return x + 1

    # This test also shows a different "style" of running workflows.
    assert workflow.run(incr.bind(0), workflow_id="test_dynamic_workflow_ref") == 1
    # Without a rerun, it just returns the previous result
    assert (
        workflow.run(
            incr.bind(WorkflowRef("incr")), workflow_id="test_dynamic_workflow_ref"
        )
        == 1
    )
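
Examples 2, 6 and 7 are the same test written against three generations of the API: decorator-based steps (`incr.step(0).run(...)`), the DAG-based `workflow.create(incr.bind(0)).run(...)`, and the direct `workflow.run(incr.bind(0), workflow_id=...)`. Note that in Example 7 the `WorkflowRef` is keyed by the task name `"incr"` rather than by a `step_id` taken from a `Workflow` object.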
Example 8
    async def _post_process_ready_task(
        self,
        task_id: TaskID,
        metadata: WorkflowExecutionMetadata,
        output_ref: WorkflowRef,
    ) -> None:
        state = self._state
        state.task_retries.pop(task_id, None)
        if metadata.is_output_workflow:  # The task returns a continuation
            sub_workflow_state: WorkflowExecutionState = await output_ref.ref
            # initialize the context just for "sub_workflow_state"
            sub_workflow_state.init_context(state.task_context[task_id])
            state.merge_state(sub_workflow_state)
            # build up the runtime dependency
            continuation_task_id = sub_workflow_state.output_task_id
            state.append_continuation(task_id, continuation_task_id)
            # Migrate callbacks: all continuation callbacks are moved under
            # the root of the continuation, so that when the continuation
            # completes, all of its callbacks can be triggered.
            if continuation_task_id in self._task_done_callbacks:
                self._task_done_callbacks[
                    state.continuation_root[continuation_task_id]].extend(
                        self._task_done_callbacks.pop(continuation_task_id))
            state.construct_scheduling_plan(sub_workflow_state.output_task_id)
        else:  # The task returns a normal object
            target_task_id = state.continuation_root.get(task_id, task_id)
            state.output_map[target_task_id] = output_ref
            if state.tasks[task_id].options.checkpoint:
                state.checkpoint_map[target_task_id] = WorkflowRef(task_id)
            state.done_tasks.add(target_task_id)
            # TODO(suquark): clean up callbacks when a result is set?
            if target_task_id in self._task_done_callbacks:
                for callback in self._task_done_callbacks[target_task_id]:
                    callback.set_result(output_ref)
            for m in state.reference_set[target_task_id]:
                # we ensure that each reference corresponds to a pending input
                state.pending_input_set[m].remove(target_task_id)
                if not state.pending_input_set[m]:
                    state.append_frontier_to_run(m)
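
The first branch above handles tasks whose output is itself a workflow (a continuation). With the DAG API from Examples 6-7, a continuation is produced roughly like this (a sketch using `workflow.continuation`; the assertion assumes default options):

import ray
from ray import workflow

@ray.remote
def double(x):
    return 2 * x

@ray.remote
def entry():
    # Returning a continuation makes `double` the workflow's output task;
    # Example 8 merges its state and migrates callbacks to its root.
    return workflow.continuation(double.bind(21))

assert workflow.run(entry.bind(), workflow_id="continuation_demo") == 42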
Example 9
    def _node_visitor(node: Any) -> Any:
        if isinstance(node, FunctionNode):
            bound_options = node._bound_options.copy()
            num_returns = bound_options.get("num_returns", 1)
            if num_returns is None:  # Ray may use `None` as the default value
                num_returns = 1
            if num_returns > 1:
                raise ValueError("Workflow steps can only have one return.")

            workflow_options = bound_options.pop("_metadata",
                                                 {}).get(WORKFLOW_OPTIONS, {})

            # If checkpoint option is not specified, inherit checkpoint
            # options from context (i.e. checkpoint options of the outer
            # step). If it is still not specified, it's True by default.
            checkpoint = workflow_options.get("checkpoint", None)
            if checkpoint is None:
                checkpoint = context.checkpoint if context is not None else True
            # When it returns a nested workflow, catch_exception
            # should be passed recursively.
            catch_exceptions = workflow_options.get("catch_exceptions", None)
            if catch_exceptions is None:
                # TODO(suquark): should we also handle exceptions from a "leaf node"
                #   in the continuation? For example, we have a workflow
                #   > @ray.remote
                #   > def A(): pass
                #   > @ray.remote
                #   > def B(x): return x
                #   > @ray.remote
                #   > def C(x): return workflow.continuation(B.bind(A.bind()))
                #   > dag = C.options(**workflow.options(catch_exceptions=True)).bind()
                #   Should C catch exceptions from A?
                if node.get_stable_uuid() == dag_node.get_stable_uuid():
                    # 'catch_exception' context should be passed down to
                    # its direct continuation task.
                    # In this case, the direct continuation is the output node.
                    catch_exceptions = (context.catch_exceptions
                                        if context is not None else False)
                else:
                    catch_exceptions = False

            max_retries = bound_options.get("max_retries", 3)
            if not isinstance(max_retries, int) or max_retries < -1:
                raise ValueError(
                    "'max_retries' only accepts 0, -1 or a positive integer.")

            step_options = WorkflowStepRuntimeOptions(
                step_type=StepType.FUNCTION,
                catch_exceptions=catch_exceptions,
                max_retries=max_retries,
                allow_inplace=False,
                checkpoint=checkpoint,
                ray_options=bound_options,
            )

            workflow_refs: List[WorkflowRef] = []
            with serialization_context.workflow_args_serialization_context(
                    workflow_refs):
                _func_signature = signature.extract_signature(node._body)
                flattened_args = signature.flatten_args(
                    _func_signature, node._bound_args, node._bound_kwargs)
                # NOTE: When calling 'ray.put', we trigger python object
                # serialization. Under our serialization context,
                # Workflows are separated from the arguments,
                # leaving a placeholder object with all other python objects.
                # Then we put the placeholder object to object store,
                # so it won't be mutated later. This guarantees correct
                # semantics. See "tests/test_variable_mutable.py" as
                # an example.
                input_placeholder: ray.ObjectRef = ray.put(flattened_args)

            name = workflow_options.get("name")
            if name is None:
                name = f"{get_module(node._body)}.{slugify(get_qualname(node._body))}"
            task_id = ray.get(mgr.gen_step_id.remote(workflow_id, name))
            state.add_dependencies(task_id, [s.task_id for s in workflow_refs])
            state.task_input_args[task_id] = input_placeholder

            user_metadata = workflow_options.pop("metadata", {})
            validate_user_metadata(user_metadata)
            state.tasks[task_id] = Task(
                name=name,
                options=step_options,
                user_metadata=user_metadata,
                func_body=node._body,
            )
            return WorkflowRef(task_id)

        if isinstance(node, InputAttributeNode):
            return node._execute_impl()  # get data from input node
        if isinstance(node, InputNode):
            return input_context  # replace input node with input data
        if not isinstance(node, DAGNode):
            return node  # return normal objects
        raise TypeError(f"Unsupported DAG node: {node}")
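
The options that `_node_visitor` reads back out of `bound_options["_metadata"]` are typically attached on the user side with `workflow.options(...)`, roughly as follows (a sketch; the option names mirror the lookups above):

import ray
from ray import workflow

@ray.remote
def fetch():
    return 42

# name/checkpoint/catch_exceptions land in the WORKFLOW_OPTIONS metadata
# that _node_visitor pops from the node's bound options.
dag = fetch.options(
    **workflow.options(name="fetch", checkpoint=False, catch_exceptions=True)
).bind()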
Example 10
    async def _handle_ready_task(self, fut: asyncio.Future, workflow_id: str,
                                 wf_store: "WorkflowStorage") -> None:
        """Handle ready task, especially about its exception."""
        state = self._state
        output_ref = state.pop_running_frontier(fut)
        task_id = output_ref.task_id
        try:
            metadata: WorkflowExecutionMetadata = fut.result()
            state.task_execution_metadata[task_id].finish_time = time.time()
            logger.info(f"Task status [{WorkflowStatus.SUCCESSFUL}]\t"
                        f"[{workflow_id}@{task_id}]")
            await self._post_process_ready_task(task_id, metadata, output_ref)
        except asyncio.CancelledError:
            # NOTE: We must update the workflow status before broadcasting
            # the exception. Otherwise, the workflow status would still be
            # 'RUNNING' if checked immediately after cancellation.
            wf_store.update_workflow_status(WorkflowStatus.CANCELED)
            logger.warning(f"Workflow '{workflow_id}' is cancelled.")
            # broadcasting cancellation to all outputs
            err = WorkflowCancellationError(workflow_id)
            self._broadcast_exception(err)
            raise err from None
        except Exception as e:
            is_application_error = False
            if isinstance(e, RayTaskError):
                reason = "an exception raised by the task"
                is_application_error = True
            elif isinstance(e, RayError):
                reason = "a system error"
            else:
                reason = "an unknown error"
            logger.error(
                f"Task status [{WorkflowStatus.FAILED}] due to {reason}.\t"
                f"[{workflow_id}@{task_id}]")

            # the ID of the task (if any) that ends up catching the error
            exception_catching_task_id = None
            # look up a creator task that catches the exception
            if is_application_error:
                for t, task in self._iter_callstack(task_id):
                    if task.options.catch_exceptions:
                        exception_catching_task_id = t
                        break

            if exception_catching_task_id is None:
                # NOTE: We must update the workflow status before broadcasting
                # the exception. Otherwise, the workflow status would still be
                # 'RUNNING' if checked immediately after the exception.
                wf_store.update_workflow_status(WorkflowStatus.FAILED)
                logger.error(f"Workflow '{workflow_id}' failed due to {e}")
                err = WorkflowExecutionError(workflow_id)
                err.__cause__ = e  # chain exceptions
                self._broadcast_exception(err)
                raise err

            logger.info(
                f"Exception raised by '{workflow_id}@{task_id}' is caught by "
                f"'{workflow_id}@{exception_catching_task_id}'")
            # assign output to exception catching task;
            # compose output with caught exception
            await self._post_process_ready_task(
                exception_catching_task_id,
                metadata=WorkflowExecutionMetadata(),
                output_ref=WorkflowRef(task_id, ray.put((None, e))),
            )
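
The `ray.put((None, e))` in the last call composes the `(result, exception)` pair that a task with `catch_exceptions=True` receives as its output. On the user side this pairing looks roughly like the following (a sketch; `faulty` is a hypothetical task):

import ray
from ray import workflow

@ray.remote
def faulty():
    raise ValueError("boom")

res, err = workflow.run(
    faulty.options(**workflow.options(catch_exceptions=True)).bind(),
    workflow_id="catch_demo",
)
assert res is None and err is not None  # err wraps the ValueError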