示例#1
0
def _orchestrate_stop_initiated_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Drives a pipeline whose stop has been initiated toward being stopped.

    Every node of the pipeline is visited once:
      * nodes reported as pure service nodes have their services stopped;
      * otherwise, a cancellation task is requested via
        `_maybe_enqueue_cancellation_task`; a truthy result marks the pipeline
        as still having active executions;
      * otherwise, nodes reported as mixed service nodes have their services
        stopped.
    Only when no node reported an active execution is the pipeline execution
    state updated in MLMD with the recorded stop reason.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task_queue: Queue handed to `_maybe_enqueue_cancellation_task`.
      service_job_manager: Used to classify nodes and stop node services.
      pipeline_state: State of the pipeline; its stop reason must be set.
    """
    reason = pipeline_state.stop_initiated_reason()
    assert reason is not None
    pipeline = pipeline_state.pipeline
    any_active = False
    for node in pstate.get_all_pipeline_nodes(pipeline):
        node_id = node.node_info.id
        if service_job_manager.is_pure_service_node(pipeline_state, node_id):
            service_job_manager.stop_node_services(pipeline_state, node_id)
            continue
        if _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                            task_queue):
            any_active = True
            continue
        if service_job_manager.is_mixed_service_node(pipeline_state, node_id):
            service_job_manager.stop_node_services(pipeline_state, node_id)

    if any_active:
        # Some executions are still being cancelled; try again later.
        return
    with pipeline_state:
        # Persist the final pipeline execution state in MLMD.
        pipeline_state.update_pipeline_execution_state(reason)
示例#2
0
def _process_stop_initiated_pipelines(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: Optional[service_jobs.ServiceJobManager],
        pipeline_states: Sequence[pstate.PipelineState]) -> None:
    """Processes stop initiated pipelines.

    For each pipeline, a cancellation task is requested for every node that is
    not a pure service node. If none of those requests reports an active
    execution, the pipeline's services are stopped (when a service job manager
    is given), and a copy of the pipeline execution is written back to MLMD
    with `last_known_state` set to CANCELED.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task_queue: Queue handed to `_maybe_enqueue_cancellation_task`.
      service_job_manager: Optional manager used to identify pure service
        nodes and stop pipeline services.
      pipeline_states: States of the pipelines whose stop was initiated.
    """
    for pipeline_state in pipeline_states:
        pipeline = pipeline_state.pipeline
        if service_job_manager:
            pure_service_ids = _get_pure_service_node_ids(
                service_job_manager, pipeline_state)
        else:
            pure_service_ids = set()
        execution = pipeline_state.execution

        cancellation_pending = False
        for node in pstate.get_all_pipeline_nodes(pipeline):
            if node.node_info.id in pure_service_ids:
                # Pure service nodes are handled by `stop_services` below.
                continue
            if _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                                task_queue):
                cancellation_pending = True

        if cancellation_pending:
            continue

        if service_job_manager is not None:
            # Stop all the services associated with the pipeline.
            service_job_manager.stop_services(pipeline_state)
        # Update pipeline execution state in MLMD.
        canceled_execution = copy.deepcopy(execution)
        canceled_execution.last_known_state = (
            metadata_store_pb2.Execution.CANCELED)
        mlmd_handle.store.put_executions([canceled_execution])
示例#3
0
def _cancel_nodes(mlmd_handle: metadata.Metadata,
                  task_queue: tq.TaskQueue,
                  service_job_manager: service_jobs.ServiceJobManager,
                  pipeline_state: pstate.PipelineState,
                  pause: bool,
                  reload_node_ids: Optional[List[str]] = None) -> bool:
    """Cancels pipeline nodes and returns `True` if any node is currently active.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task_queue: Queue handed to `_maybe_enqueue_cancellation_task`.
      service_job_manager: Used to classify nodes and stop node services.
      pipeline_state: State of the pipeline whose nodes are cancelled.
      pause: Forwarded as `pause=` to `_maybe_enqueue_cancellation_task`.
      reload_node_ids: If not `None`, only nodes with these ids are handled.

    Returns:
      `True` if some node's services failed to report stopped, or a
      cancellation task request returned a truthy value; `False` otherwise.
    """
    pipeline = pipeline_state.pipeline
    still_active = False
    for node in pstate.get_all_pipeline_nodes(pipeline):
        node_id = node.node_info.id
        # TODO(b/217584342): Partial reload which excludes service nodes is not
        # fully supported in async pipelines since we don't have a mechanism to
        # reload them later for new executions.
        if reload_node_ids is not None and node_id not in reload_node_ids:
            continue

        if service_job_manager.is_pure_service_node(pipeline_state, node_id):
            stopped = service_job_manager.stop_node_services(
                pipeline_state, node_id)
        elif _maybe_enqueue_cancellation_task(mlmd_handle,
                                              pipeline,
                                              node,
                                              task_queue,
                                              pause=pause):
            stopped = False
        elif service_job_manager.is_mixed_service_node(pipeline_state,
                                                       node_id):
            stopped = service_job_manager.stop_node_services(
                pipeline_state, node_id)
        else:
            stopped = True

        if not stopped:
            still_active = True
    return still_active
示例#4
0
def _get_pure_service_node_ids(
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> Set[str]:
    """Returns ids of pipeline nodes that `service_job_manager` reports as pure service nodes.

    Args:
      service_job_manager: Manager whose `is_pure_service_node` classifies
        each node.
      pipeline_state: State of the pipeline whose nodes are inspected.

    Returns:
      The set of matching node ids.
    """
    return {
        node.node_info.id
        for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
        if service_job_manager.is_pure_service_node(pipeline_state,
                                                    node.node_info.id)
    }
示例#5
0
def _get_stop_initiated_nodes(
    pipeline_state: pstate.PipelineState) -> List[pipeline_pb2.PipelineNode]:
  """Returns the pipeline nodes that have a node stop initiated reason set.

  Args:
    pipeline_state: State of the pipeline whose nodes are inspected.

  Returns:
    Nodes, in pipeline order, for which `node_stop_initiated_reason` is not
    `None`.
  """
  return [
      node for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
      if pipeline_state.node_stop_initiated_reason(
          task_lib.NodeUid.from_pipeline_node(pipeline_state.pipeline,
                                              node)) is not None
  ]
示例#6
0
def stop_node(
        mlmd_handle: metadata.Metadata,
        node_uid: task_lib.NodeUid,
        timeout_secs: float = DEFAULT_WAIT_FOR_INACTIVATION_TIMEOUT_SECS
) -> None:
    """Stops a node in an async pipeline.

    Initiates a node stop operation and waits for the node execution to become
    inactive. If the node has no active execution, returns without waiting.

    Args:
      mlmd_handle: A handle to the MLMD db.
      node_uid: Uid of the node to be stopped.
      timeout_secs: Amount of time in seconds to wait for node to stop.

    Raises:
      status_lib.StatusNotOkError: Failure to stop the node. With code
        `INTERNAL` if the node cannot be uniquely found, or if the node has
        more than one active execution.
    """
    with _PIPELINE_OPS_LOCK:
        with pstate.PipelineState.load(
                mlmd_handle, node_uid.pipeline_uid) as pipeline_state:
            matches = [
                candidate for candidate in pstate.get_all_pipeline_nodes(
                    pipeline_state.pipeline)
                if candidate.node_info.id == node_uid.node_id
            ]
            if len(matches) != 1:
                raise status_lib.StatusNotOkError(
                    code=status_lib.Code.INTERNAL,
                    message=(
                        f'`stop_node` operation failed, unable to find node to stop: '
                        f'{node_uid}'))
            (node,) = matches
            pipeline_state.initiate_node_stop(
                node_uid,
                status_lib.Status(code=status_lib.Code.CANCELLED,
                                  message='Cancellation requested by client.'))

        active_executions = [
            execution
            for execution in task_gen_utils.get_executions(mlmd_handle, node)
            if execution_lib.is_execution_active(execution)
        ]
        if not active_executions:
            # Nothing is running for this node; no need to wait.
            return
        if len(active_executions) > 1:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.INTERNAL,
                message=(
                    f'Unexpected multiple active executions for node: {node_uid}'))
        execution_to_wait_on = active_executions[0]

    # Wait outside the lock so other pipeline ops are not blocked meanwhile.
    _wait_for_inactivation(mlmd_handle,
                           execution_to_wait_on,
                           timeout_secs=timeout_secs)
示例#7
0
def _get_node_infos(pipeline_state: pstate.PipelineState) -> List[_NodeInfo]:
    """Builds a `_NodeInfo` (node plus its current node state) per pipeline node.

    Node states are read while `pipeline_state` is entered as a context
    manager.

    Args:
      pipeline_state: State of the pipeline whose nodes are described.

    Returns:
      One `_NodeInfo` per node, in pipeline order.
    """
    nodes = pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
    with pipeline_state:
        return [
            _NodeInfo(node=node,
                      state=pipeline_state.get_node_state(
                          task_lib.NodeUid.from_pipeline_node(
                              pipeline_state.pipeline, node)))
            for node in nodes
        ]
示例#8
0
def resume_manual_node(mlmd_handle: metadata.Metadata,
                       node_uid: task_lib.NodeUid) -> None:
    """Resumes a manual node.

    Locates the single active execution of the manual node and, within an
    atomic MLMD execution update, sets its
    `manual_task_scheduler.NODE_STATE_PROPERTY_KEY` custom property to
    `ManualNodeState.COMPLETED`.

    Args:
      mlmd_handle: A handle to the MLMD db.
      node_uid: Uid of the manual node to be resumed.

    Raises:
      status_lib.StatusNotOkError: Failure to resume a manual node. With code
        `NOT_FOUND` if the node or an active execution of it is not found,
        `INVALID_ARGUMENT` if the node is not a manual node, and `INTERNAL`
        if more than one active execution exists.
    """
    logging.info('Received request to resume manual node; node uid: %s',
                 node_uid)
    with pstate.PipelineState.load(mlmd_handle,
                                   node_uid.pipeline_uid) as pipeline_state:
        matches = [
            candidate for candidate in pstate.get_all_pipeline_nodes(
                pipeline_state.pipeline)
            if candidate.node_info.id == node_uid.node_id
        ]
        if len(matches) != 1:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.NOT_FOUND,
                message=(f'Unable to find manual node to resume: {node_uid}'))
        node = matches[0]
        if node.node_info.type.name != constants.MANUAL_NODE_TYPE:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.INVALID_ARGUMENT,
                message=('Unable to resume a non-manual node. '
                         f'Got non-manual node id: {node_uid}'))

    active_executions = [
        execution
        for execution in task_gen_utils.get_executions(mlmd_handle, node)
        if execution_lib.is_execution_active(execution)
    ]
    if not active_executions:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.NOT_FOUND,
            message=(
                f'Unable to find active manual node to resume: {node_uid}'))
    if len(active_executions) > 1:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.INTERNAL,
            message=(f'Unexpected multiple active executions for manual node: '
                     f'{node_uid}'))
    (target_execution,) = active_executions

    with mlmd_state.mlmd_execution_atomic_op(
            mlmd_handle=mlmd_handle,
            execution_id=target_execution.id) as execution:
        completed = manual_task_scheduler.ManualNodeState(
            state=manual_task_scheduler.ManualNodeState.COMPLETED)
        state_property = execution.custom_properties.get_or_create(
            manual_task_scheduler.NODE_STATE_PROPERTY_KEY)
        completed.set_mlmd_value(state_property)
示例#9
0
def _process_stop_initiated_pipelines(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        pipeline_states: Sequence[pstate.PipelineState]) -> None:
    """Processes stop initiated pipelines.

    A cancellation task is requested for every node of each pipeline. If no
    request reports an active execution, a copy of the pipeline execution is
    written back to MLMD with `last_known_state` set to CANCELED.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task_queue: Queue handed to `_maybe_enqueue_cancellation_task`.
      pipeline_states: States of the pipelines whose stop was initiated.
    """
    for state in pipeline_states:
        pipeline = state.pipeline
        execution = state.execution
        # Request cancellation for every node (no short-circuit), then check.
        cancellation_results = [
            _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                             task_queue)
            for node in pstate.get_all_pipeline_nodes(pipeline)
        ]
        if any(cancellation_results):
            continue
        canceled_execution = copy.deepcopy(execution)
        canceled_execution.last_known_state = (
            metadata_store_pb2.Execution.CANCELED)
        mlmd_handle.store.put_executions([canceled_execution])
示例#10
0
def stop_node(mlmd_handle: metadata.Metadata,
              node_uid: task_lib.NodeUid,
              timeout_secs: Optional[float] = None) -> None:
    """Stops a node.

    Initiates a node stop operation and waits for the node execution to become
    inactive. The node is moved to STOPPING only if its current node state is
    stoppable; the wait is performed in either case.

    Args:
      mlmd_handle: A handle to the MLMD db.
      node_uid: Uid of the node to be stopped.
      timeout_secs: Amount of time in seconds to wait for node to stop. If
        `None`, waits indefinitely.

    Raises:
      status_lib.StatusNotOkError: Failure to stop the node. With code
        `INTERNAL` if the node cannot be uniquely found in the pipeline.
    """
    logging.info('Received request to stop node; node uid: %s', node_uid)
    with _PIPELINE_OPS_LOCK:
        with pstate.PipelineState.load(
                mlmd_handle, node_uid.pipeline_uid) as pipeline_state:
            nodes = pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
            filtered_nodes = [
                n for n in nodes if n.node_info.id == node_uid.node_id
            ]
            # Only existence is validated here; the node's state is addressed
            # through `node_uid`, so the node object itself is not needed.
            if len(filtered_nodes) != 1:
                raise status_lib.StatusNotOkError(
                    code=status_lib.Code.INTERNAL,
                    message=
                    (f'`stop_node` operation failed, unable to find node to stop: '
                     f'{node_uid}'))
            with pipeline_state.node_state_update_context(
                    node_uid) as node_state:
                if node_state.is_stoppable():
                    node_state.update(
                        pstate.NodeState.STOPPING,
                        status_lib.Status(
                            code=status_lib.Code.CANCELLED,
                            message='Cancellation requested by client.'))

    # Wait until the node is stopped or time out. This happens outside the
    # lock so other pipeline ops are not blocked while waiting.
    _wait_for_node_inactivation(pipeline_state,
                                node_uid,
                                timeout_secs=timeout_secs)
示例#11
0
def _orchestrate_active_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates active pipeline.

    Performs one orchestration iteration for an active pipeline:

    1. Sets the pipeline execution state to RUNNING if it is not already. For
       SYNC pipelines with a positive `deadline_secs` orchestration option,
       initiates a DEADLINE_EXCEEDED stop and returns once the time since
       pipeline creation exceeds the deadline.
    2. For nodes in state STOPPING, stops node services or requests
       cancellation tasks, and moves the nodes found to be stopped to STOPPED.
    3. Runs the SYNC or ASYNC task generator. It then applies the generated
       `UpdateNodeStateTask`s, moves any nodes still in STARTING to STARTED,
       and enqueues `ExecNodeTask`s. For a `FinalizePipelineTask`, it
       initiates a pipeline stop with that task's status.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task_queue: Queue receiving exec node tasks (and handed to
        `_maybe_enqueue_cancellation_task`); its `contains_task_id` is passed
        to the task generator.
      service_job_manager: Used to classify nodes and stop node services; also
        passed to the task generator.
      pipeline_state: State of the pipeline; asserted to be active.

    Raises:
      status_lib.StatusNotOkError: With code FAILED_PRECONDITION if the
        pipeline execution mode is neither SYNC nor ASYNC.
    """
    pipeline = pipeline_state.pipeline
    with pipeline_state:
        assert pipeline_state.is_active()
        if pipeline_state.get_pipeline_execution_state() != (
                metadata_store_pb2.Execution.RUNNING):
            pipeline_state.set_pipeline_execution_state(
                metadata_store_pb2.Execution.RUNNING)
        # NOTE: `orchestration_options` is also read further below, when the
        # SYNC task generator is constructed (`fail_fast`).
        orchestration_options = pipeline_state.get_orchestration_options()
        logging.info('Orchestration options: %s', orchestration_options)
        deadline_secs = orchestration_options.deadline_secs
        # Deadline only applies to SYNC pipelines and only when positive.
        if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                and deadline_secs > 0 and time.time() -
                pipeline_state.pipeline_creation_time_secs_since_epoch() >
                deadline_secs):
            logging.error(
                'Aborting pipeline due to exceeding deadline (%s secs); '
                'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
            pipeline_state.initiate_stop(
                status_lib.Status(
                    code=status_lib.Code.DEADLINE_EXCEEDED,
                    message=('Pipeline aborted due to exceeding deadline '
                             f'({deadline_secs} secs)')))
            return

    def _filter_by_state(node_infos: List[_NodeInfo],
                         state_str: str) -> List[_NodeInfo]:
        """Returns the node infos whose node state equals `state_str`."""
        return [n for n in node_infos if n.state.state == state_str]

    node_infos = _get_node_infos(pipeline_state)
    stopping_node_infos = _filter_by_state(node_infos,
                                           pstate.NodeState.STOPPING)

    # Tracks nodes stopped in the current iteration.
    stopped_node_infos: List[_NodeInfo] = []

    # Create cancellation tasks for nodes in state STOPPING.
    for node_info in stopping_node_infos:
        if service_job_manager.is_pure_service_node(
                pipeline_state, node_info.node.node_info.id):
            # Pure service node: counted as stopped only if stopping its
            # services reports success.
            if service_job_manager.stop_node_services(
                    pipeline_state, node_info.node.node_info.id):
                stopped_node_infos.append(node_info)
        elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                              node_info.node, task_queue):
            # Presumably a cancellation is still in flight for this node;
            # leave it in STOPPING for a later iteration.
            pass
        elif service_job_manager.is_mixed_service_node(
                pipeline_state, node_info.node.node_info.id):
            if service_job_manager.stop_node_services(
                    pipeline_state, node_info.node.node_info.id):
                stopped_node_infos.append(node_info)
        else:
            # Nothing to cancel and no services to stop.
            stopped_node_infos.append(node_info)

    # Change the state of stopped nodes from STOPPING to STOPPED.
    if stopped_node_infos:
        with pipeline_state:
            for node_info in stopped_node_infos:
                node_uid = task_lib.NodeUid.from_pipeline_node(
                    pipeline, node_info.node)
                with pipeline_state.node_state_update_context(
                        node_uid) as node_state:
                    # Reuse the node's current status for the STOPPED state.
                    node_state.update(pstate.NodeState.STOPPED,
                                      node_state.status)

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle,
            task_queue.contains_task_id,
            service_job_manager,
            fail_fast=orchestration_options.fail_fast)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, task_queue.contains_task_id, service_job_manager)
    else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    tasks = generator.generate(pipeline_state)

    with pipeline_state:
        # Handle all the UpdateNodeStateTasks by updating node states.
        for task in tasks:
            if task_lib.is_update_node_state_task(task):
                task = typing.cast(task_lib.UpdateNodeStateTask, task)
                with pipeline_state.node_state_update_context(
                        task.node_uid) as node_state:
                    node_state.update(task.state, task.status)

        # Node state updates have been applied; drop them from further handling.
        tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

        # If there are still nodes in state STARTING, change them to STARTED.
        for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
            node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline_state.pipeline, node)
            with pipeline_state.node_state_update_context(
                    node_uid) as node_state:
                if node_state.state == pstate.NodeState.STARTING:
                    node_state.update(pstate.NodeState.STARTED)

        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task = typing.cast(task_lib.ExecNodeTask, task)
                task_queue.enqueue(task)
            else:
                # Any other task must be a FinalizePipelineTask, which is only
                # expected for SYNC pipelines and as the sole remaining task.
                assert task_lib.is_finalize_pipeline_task(task)
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                assert len(tasks) == 1
                task = typing.cast(task_lib.FinalizePipelineTask, task)
                if task.status.code == status_lib.Code.OK:
                    logging.info('Pipeline run successful; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                else:
                    logging.info('Pipeline run failed; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                pipeline_state.initiate_stop(task.status)
示例#12
0
def resume_pipeline(mlmd_handle: metadata.Metadata,
                    pipeline: pipeline_pb2.Pipeline) -> pstate.PipelineState:
    """Resumes a pipeline run from previously failed nodes.

    The most recently created previous run of the pipeline is located. A new
    run is then started as a partial run over all nodes of `pipeline`. It
    skips the nodes that succeeded in that previous run and uses the
    latest-pipeline-run snapshot strategy. Upon success, MLMD is updated to
    signal that the pipeline must be started.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: IR of the pipeline to resume.

    Returns:
      The `PipelineState` object upon success.

    Raises:
      status_lib.StatusNotOkError: Failure to resume pipeline.
        - `FAILED_PRECONDITION`: `pipeline` or its previous run is not in
          SYNC execution mode.
        - `ALREADY_EXISTS`: a pipeline is already running.
        - `NOT_FOUND`: no previous pipeline run is found for resuming.
    """
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    logging.info('Received request to resume pipeline; pipeline uid: %s',
                 pipeline_uid)
    if pipeline.execution_mode != pipeline_pb2.Pipeline.SYNC:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    # Find the most recently created previous run; refuse to resume while any
    # run of this pipeline is still active.
    latest_pipeline_view = None
    for view in pstate.PipelineView.load_all(mlmd_handle, pipeline_uid):
        execution = view.execution
        if execution_lib.is_execution_active(execution):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.ALREADY_EXISTS,
                message=(
                    f'Can not resume pipeline. An active pipeline is already '
                    f'running with uid {pipeline_uid}.'))
        if (latest_pipeline_view is None or
                execution.create_time_since_epoch >
                latest_pipeline_view.execution.create_time_since_epoch):
            latest_pipeline_view = view

    if latest_pipeline_view is None:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.NOT_FOUND,
            message='Pipeline failed to resume. No previous pipeline run found.'
        )
    if latest_pipeline_view.pipeline.execution_mode != pipeline_pb2.Pipeline.SYNC:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=
            (f'Only SYNC pipeline execution modes supported; previous pipeline '
             f'run has execution mode: '
             f'{latest_pipeline_view.pipeline.execution_mode}'))

    # Nodes that succeeded in the latest run are skipped; every other node of
    # the pipeline is re-run.
    previously_succeeded_nodes = [
        node_id for node_id, node_state in
        latest_pipeline_view.get_node_states_dict().items()
        if node_state.is_success()
    ]
    pipeline_nodes = [
        node.node_info.id for node in pstate.get_all_pipeline_nodes(pipeline)
    ]
    latest_pipeline_snapshot_settings = pipeline_pb2.SnapshotSettings()
    latest_pipeline_snapshot_settings.latest_pipeline_run_strategy.SetInParent()
    partial_run_option = pipeline_pb2.PartialRun(
        from_nodes=pipeline_nodes,
        to_nodes=pipeline_nodes,
        skip_nodes=previously_succeeded_nodes,
        snapshot_settings=latest_pipeline_snapshot_settings)

    return initiate_pipeline_start(mlmd_handle,
                                   pipeline,
                                   partial_run_option=partial_run_option)