Beispiel #1
0
def _cancel_nodes(mlmd_handle: metadata.Metadata,
                  task_queue: tq.TaskQueue,
                  service_job_manager: service_jobs.ServiceJobManager,
                  pipeline_state: pstate.PipelineState,
                  pause: bool,
                  reload_node_ids: Optional[List[str]] = None) -> bool:
    """Requests cancellation of pipeline nodes; reports if any is still active.

    Args:
      mlmd_handle: Forwarded to `_maybe_enqueue_cancellation_task`.
      task_queue: Forwarded to `_maybe_enqueue_cancellation_task`.
      service_job_manager: Used to classify nodes as pure or mixed service
        nodes and to stop their services.
      pipeline_state: State of the pipeline whose nodes are visited.
      pause: Forwarded as the `pause` keyword argument to
        `_maybe_enqueue_cancellation_task`.
      reload_node_ids: When not `None`, only nodes whose id appears in this
        list are processed; all other nodes are skipped.

    Returns:
      `True` if, for at least one processed node, either
      `stop_node_services` returned a falsy value or
      `_maybe_enqueue_cancellation_task` returned a truthy value; `False`
      otherwise.
    """
    pipeline = pipeline_state.pipeline
    any_node_active = False
    for node in pstate.get_all_pipeline_nodes(pipeline):
        node_id = node.node_info.id
        # TODO(b/217584342): Partial reload which excludes service nodes is not
        # fully supported in async pipelines since we don't have a mechanism to
        # reload them later for new executions.
        if reload_node_ids is not None and node_id not in reload_node_ids:
            continue

        # The branch order matters: pure service nodes never get a cancellation
        # task, and services of a mixed node are stopped only when no
        # cancellation task was enqueued for it.
        if service_job_manager.is_pure_service_node(pipeline_state, node_id):
            node_active = not service_job_manager.stop_node_services(
                pipeline_state, node_id)
        elif _maybe_enqueue_cancellation_task(
                mlmd_handle, pipeline, node, task_queue, pause=pause):
            node_active = True
        elif service_job_manager.is_mixed_service_node(pipeline_state,
                                                       node_id):
            node_active = not service_job_manager.stop_node_services(
                pipeline_state, node_id)
        else:
            node_active = False

        any_node_active = any_node_active or node_active
    return any_node_active
Beispiel #2
0
def _ensure_services(service_jobs_manager: service_jobs.ServiceJobManager,
                     pipeline_state: pstate.PipelineState) -> None:
    """Ensures services for the pipeline and stops nodes that are reported back.

    Calls `ensure_services` on the service job manager and, for every node uid
    it returns, initiates a node stop on the pipeline state.

    Args:
      service_jobs_manager: Manager whose `ensure_services` is invoked.
      pipeline_state: Pipeline state passed to the manager and updated with
        the node stop requests.
    """
    uids_to_stop = service_jobs_manager.ensure_services(pipeline_state)
    if not uids_to_stop:
        return
    # All stop requests are issued inside a single pipeline state context.
    with pipeline_state:
        for uid in uids_to_stop:
            pipeline_state.initiate_node_stop(uid)
Beispiel #3
0
def _get_pure_service_node_ids(
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> Set[str]:
    """Returns the ids of all pipeline nodes that are pure service nodes.

    Args:
      service_job_manager: Manager queried via `is_pure_service_node`.
      pipeline_state: State of the pipeline whose nodes are inspected.

    Returns:
      Set of node ids for which `is_pure_service_node` returned a truthy
      value; empty if there are none.
    """
    all_node_ids = (
        node.node_info.id
        for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline))
    return {
        node_id for node_id in all_node_ids
        if service_job_manager.is_pure_service_node(pipeline_state, node_id)
    }
Beispiel #4
0
def _orchestrate_stop_initiated_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Drives a pipeline for which a stop has been initiated.

  Stops the services of pure service nodes and tries to enqueue a cancellation
  task for every other node. If no cancellation task helper call returned a
  truthy value, the pipeline execution state is updated with the stop reason.

  Args:
    mlmd_handle: Forwarded to `_maybe_enqueue_cancellation_task`.
    task_queue: Forwarded to `_maybe_enqueue_cancellation_task`.
    service_job_manager: Used to detect pure service nodes and stop their
      services.
    pipeline_state: State of the stop-initiated pipeline.
  """
  stop_reason = pipeline_state.stop_initiated_reason()
  assert stop_reason is not None
  pipeline = pipeline_state.pipeline

  cancellation_pending = False
  for node in pstate.get_all_pipeline_nodes(pipeline):
    node_id = node.node_info.id
    if service_job_manager.is_pure_service_node(pipeline_state, node_id):
      # The result of stopping services is not consulted here.
      service_job_manager.stop_node_services(pipeline_state, node_id)
      continue
    if _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                        task_queue):
      cancellation_pending = True

  if cancellation_pending:
    return

  with pipeline_state:
    # Update pipeline execution state in MLMD.
    pipeline_state.update_pipeline_execution_state(stop_reason)
Beispiel #5
0
def _orchestrate_active_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Runs one orchestration pass over an active pipeline.

    The pass marks the pipeline execution as RUNNING if it is not already,
    builds the task generator matching the pipeline's execution mode (for
    ASYNC pipelines, stop-initiated nodes are handled first), generates tasks
    and then dispatches each task: exec-node tasks are enqueued, finalize-node
    tasks initiate a node stop, and a finalize-pipeline task initiates a
    pipeline stop.

    Args:
      mlmd_handle: Metadata handle; its store receives the updated execution
        and it is passed to the task generators.
      task_queue: Queue that receives exec-node tasks; its `contains_task_id`
        is handed to the task generator.
      service_job_manager: Used to stop services of stop-initiated nodes and
        passed to the task generators.
      pipeline_state: State of the pipeline being orchestrated.

    Raises:
      status_lib.StatusNotOkError: With code FAILED_PRECONDITION if the
        pipeline execution mode is neither SYNC nor ASYNC.
    """
    pipeline = pipeline_state.pipeline
    execution = pipeline_state.execution
    running = metadata_store_pb2.Execution.RUNNING
    assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                          running)
    if execution.last_known_state != running:
        # Store a modified copy; the execution held by pipeline_state is not
        # mutated here.
        running_execution = copy.deepcopy(execution)
        running_execution.last_known_state = running
        mlmd_handle.store.put_executions([running_execution])

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        # Create cancellation tasks for stop-initiated nodes if necessary.
        stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
        for node in stop_initiated_nodes:
            node_id = node.node_info.id
            # Services are stopped for pure service nodes, and for mixed
            # service nodes only when no cancellation task was enqueued.
            if service_job_manager.is_pure_service_node(
                    pipeline_state, node_id):
                stop_services = True
            elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                                  task_queue):
                stop_services = False
            else:
                stop_services = service_job_manager.is_mixed_service_node(
                    pipeline_state, node_id)
            if stop_services:
                service_job_manager.stop_node_services(pipeline_state, node_id)
        stop_initiated_node_ids = {
            n.node_info.id for n in stop_initiated_nodes
        }
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager, stop_initiated_node_ids)
    else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    tasks = generator.generate()

    with pipeline_state:
        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task_queue.enqueue(typing.cast(task_lib.ExecNodeTask, task))
                continue

            if task_lib.is_finalize_node_task(task):
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
                node_task = typing.cast(task_lib.FinalizeNodeTask, task)
                pipeline_state.initiate_node_stop(node_task.node_uid,
                                                  node_task.status)
                continue

            # Anything else must be the sole finalize-pipeline task of a SYNC
            # pipeline.
            assert task_lib.is_finalize_pipeline_task(task)
            assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
            assert len(tasks) == 1
            pipeline_task = typing.cast(task_lib.FinalizePipelineTask, task)
            if pipeline_task.status.code == status_lib.Code.OK:
                outcome_format = 'Pipeline run successful; pipeline uid: %s'
            else:
                outcome_format = 'Pipeline run failed; pipeline uid: %s'
            logging.info(outcome_format, pipeline_state.pipeline_uid)
            pipeline_state.initiate_stop(pipeline_task.status)
Beispiel #6
0
def _orchestrate_active_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Runs one orchestration pass over an active pipeline.

    Steps, in order:
      1. Set the pipeline execution state to RUNNING if it is not already,
         and abort a SYNC pipeline whose positive deadline has been exceeded
         by initiating a stop with DEADLINE_EXCEEDED.
      2. For nodes in state STOPPING, stop services and/or enqueue
         cancellation tasks; nodes found to be fully stopped are moved to
         STOPPED.
      3. Generate tasks with the generator matching the execution mode.
      4. Apply node state updates, move remaining STARTING nodes to STARTED,
         enqueue exec-node tasks and handle a finalize-pipeline task by
         initiating a pipeline stop.

    Args:
      mlmd_handle: Metadata handle passed to the cancellation helper and the
        task generators.
      task_queue: Queue that receives exec-node tasks; its `contains_task_id`
        is handed to the task generator.
      service_job_manager: Used to classify service nodes, stop their
        services, and passed to the task generators.
      pipeline_state: State of the pipeline being orchestrated.

    Raises:
      status_lib.StatusNotOkError: With code FAILED_PRECONDITION if the
        pipeline execution mode is neither SYNC nor ASYNC.
    """
    pipeline = pipeline_state.pipeline
    with pipeline_state:
        assert pipeline_state.is_active()
        running = metadata_store_pb2.Execution.RUNNING
        if pipeline_state.get_pipeline_execution_state() != running:
            pipeline_state.set_pipeline_execution_state(running)
        orchestration_options = pipeline_state.get_orchestration_options()
        logging.info('Orchestration options: %s', orchestration_options)
        deadline_secs = orchestration_options.deadline_secs
        # Only SYNC pipelines with a positive deadline are subject to the
        # deadline check; the clock is read only when both conditions hold.
        deadline_exceeded = (
            pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
            and deadline_secs > 0
            and time.time() -
            pipeline_state.pipeline_creation_time_secs_since_epoch() >
            deadline_secs)
        if deadline_exceeded:
            logging.error(
                'Aborting pipeline due to exceeding deadline (%s secs); '
                'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
            abort_status = status_lib.Status(
                code=status_lib.Code.DEADLINE_EXCEEDED,
                message=('Pipeline aborted due to exceeding deadline '
                         f'({deadline_secs} secs)'))
            pipeline_state.initiate_stop(abort_status)
            return

    node_infos = _get_node_infos(pipeline_state)
    stopping_state = pstate.NodeState.STOPPING
    stopping_node_infos = [
        info for info in node_infos if info.state.state == stopping_state
    ]

    # Tracks nodes stopped in the current iteration.
    stopped_node_infos: List[_NodeInfo] = []

    # Create cancellation tasks for nodes in state STOPPING.
    for node_info in stopping_node_infos:
        node_id = node_info.node.node_info.id
        # A node counts as stopped when its services report having stopped,
        # or when it is neither a service node nor had a cancellation task
        # enqueued for it.
        if service_job_manager.is_pure_service_node(pipeline_state, node_id):
            fully_stopped = service_job_manager.stop_node_services(
                pipeline_state, node_id)
        elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                              node_info.node, task_queue):
            fully_stopped = False
        elif service_job_manager.is_mixed_service_node(pipeline_state,
                                                       node_id):
            fully_stopped = service_job_manager.stop_node_services(
                pipeline_state, node_id)
        else:
            fully_stopped = True
        if fully_stopped:
            stopped_node_infos.append(node_info)

    # Change the state of stopped nodes from STOPPING to STOPPED.
    if stopped_node_infos:
        with pipeline_state:
            for node_info in stopped_node_infos:
                stopped_uid = task_lib.NodeUid.from_pipeline_node(
                    pipeline, node_info.node)
                with pipeline_state.node_state_update_context(
                        stopped_uid) as node_state:
                    # Keep the status already recorded on the node.
                    node_state.update(pstate.NodeState.STOPPED,
                                      node_state.status)

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle,
            task_queue.contains_task_id,
            service_job_manager,
            fail_fast=orchestration_options.fail_fast)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, task_queue.contains_task_id, service_job_manager)
    else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    tasks = generator.generate(pipeline_state)

    with pipeline_state:
        # Handle all the UpdateNodeStateTasks by updating node states.
        for task in tasks:
            if not task_lib.is_update_node_state_task(task):
                continue
            update_task = typing.cast(task_lib.UpdateNodeStateTask, task)
            with pipeline_state.node_state_update_context(
                    update_task.node_uid) as node_state:
                node_state.update(update_task.state, update_task.status)

        # Only the non-update tasks remain to be dispatched below.
        tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

        # If there are still nodes in state STARTING, change them to STARTED.
        current_pipeline = pipeline_state.pipeline
        for node in pstate.get_all_pipeline_nodes(current_pipeline):
            starting_uid = task_lib.NodeUid.from_pipeline_node(
                current_pipeline, node)
            with pipeline_state.node_state_update_context(
                    starting_uid) as node_state:
                if node_state.state == pstate.NodeState.STARTING:
                    node_state.update(pstate.NodeState.STARTED)

        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task_queue.enqueue(typing.cast(task_lib.ExecNodeTask, task))
                continue

            # Anything else must be the sole finalize-pipeline task of a SYNC
            # pipeline.
            assert task_lib.is_finalize_pipeline_task(task)
            assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
            assert len(tasks) == 1
            pipeline_task = typing.cast(task_lib.FinalizePipelineTask, task)
            if pipeline_task.status.code == status_lib.Code.OK:
                outcome_format = 'Pipeline run successful; pipeline uid: %s'
            else:
                outcome_format = 'Pipeline run failed; pipeline uid: %s'
            logging.info(outcome_format, pipeline_state.pipeline_uid)
            pipeline_state.initiate_stop(pipeline_task.status)