예제 #1
0
def _process_stop_initiated_pipelines(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    pipeline_details: Sequence[_PipelineDetail]) -> None:
  """Processes stop initiated pipelines.

  For each stop-initiated pipeline, attempts to cancel all nodes that still
  have active work; once no node has an active execution or an enqueued
  ExecNodeTask, the pipeline execution is recorded as CANCELED in MLMD.
  """
  for pipeline_detail in pipeline_details:
    pipeline = pipeline_detail.pipeline_state.pipeline
    execution = pipeline_detail.pipeline_state.execution
    # Tracks whether any node still has live work; such pipelines cannot be
    # marked CANCELED in this iteration.
    found_active = False
    for node in _get_all_pipeline_nodes(pipeline):
      # A node with an ExecNodeTask in the task queue gets a cancellation
      # request. A node with an active MLMD execution but no enqueued
      # ExecNodeTask may result from an orchestrator restart that happened
      # after stop was initiated but before the schedulers could finish; in
      # that case enqueue an ExecNodeTask with is_cancelled set so the
      # scheduler gets a chance to finish gracefully.
      task_id = task_lib.exec_node_task_id_from_pipeline_node(pipeline, node)
      if task_queue.contains_task_id(task_id):
        node_uid = task_lib.NodeUid.from_pipeline_node(pipeline, node)
        task_queue.enqueue(task_lib.CancelNodeTask(node_uid=node_uid))
        found_active = True
        continue
      node_executions = task_gen_utils.get_executions(mlmd_handle, node)
      cancel_task = task_gen_utils.generate_task_from_active_execution(
          mlmd_handle, pipeline, node, node_executions, is_cancelled=True)
      if cancel_task:
        task_queue.enqueue(cancel_task)
        found_active = True
    if not found_active:
      # No live work remains anywhere in the pipeline; record the pipeline
      # execution as CANCELED.
      canceled_execution = copy.deepcopy(execution)
      canceled_execution.last_known_state = (
          metadata_store_pb2.Execution.CANCELED)
      mlmd_handle.store.put_executions([canceled_execution])
예제 #2
0
def _process_active_pipelines(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: Optional[service_jobs.ServiceJobManager],
        pipeline_states: Sequence[pstate.PipelineState]) -> None:
    """Processes active pipelines.

    For each active pipeline: promotes NEW executions to RUNNING in MLMD,
    ensures service jobs (when a service job manager is provided), issues
    cancellations for stop-initiated nodes, then runs the mode-appropriate
    task generator and enqueues the resulting tasks.
    """
    for pipeline_state in pipeline_states:
        pipeline = pipeline_state.pipeline
        execution = pipeline_state.execution
        # Only NEW or RUNNING pipelines are expected here.
        assert execution.last_known_state in (
            metadata_store_pb2.Execution.NEW,
            metadata_store_pb2.Execution.RUNNING)
        # Promote a freshly created pipeline execution to RUNNING in MLMD.
        if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
            running_execution = copy.deepcopy(execution)
            running_execution.last_known_state = (
                metadata_store_pb2.Execution.RUNNING)
            mlmd_handle.store.put_executions([running_execution])

        if service_job_manager is None:
            pure_service_node_ids = set()
        else:
            # Ensure all the required services are running.
            _ensure_services(service_job_manager, pipeline_state)
            pure_service_node_ids = _get_pure_service_node_ids(
                service_job_manager, pipeline_state)

        # Issue cancellation tasks for stop-initiated nodes; pure service
        # nodes are excluded since they are not driven by ExecNodeTasks.
        stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
        for node in stop_initiated_nodes:
            if node.node_info.id in pure_service_node_ids:
                continue
            _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                             task_queue)

        # Task generation skips both stop-initiated and pure service nodes.
        ignore_node_ids = pure_service_node_ids | {
            n.node_info.id for n in stop_initiated_nodes
        }

        # Pick the task generator matching the pipeline's execution mode.
        if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
                mlmd_handle, pipeline, task_queue.contains_task_id,
                ignore_node_ids)
        elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
            generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
                mlmd_handle, pipeline, task_queue.contains_task_id,
                ignore_node_ids)
        else:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.FAILED_PRECONDITION,
                message=(
                    f'Only SYNC and ASYNC pipeline execution modes supported; '
                    f'found pipeline with execution mode: {pipeline.execution_mode}'
                ))

        # TODO(goutham): Consider concurrent task generation.
        for task in generator.generate():
            task_queue.enqueue(task)
예제 #3
0
def _process_active_pipelines(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    pipeline_details: Sequence[_PipelineDetail]) -> None:
  """Processes active pipelines.

  Promotes NEW pipeline executions to RUNNING in MLMD, then generates tasks
  for each pipeline and enqueues them.
  """
  for pipeline_detail in pipeline_details:
    execution = pipeline_detail.pipeline_state.execution
    # Only NEW or RUNNING pipelines are expected here.
    allowed_states = (metadata_store_pb2.Execution.NEW,
                      metadata_store_pb2.Execution.RUNNING)
    assert execution.last_known_state in allowed_states
    if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
      # Promote a freshly created pipeline execution to RUNNING in MLMD.
      running_execution = copy.deepcopy(execution)
      running_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
      mlmd_handle.store.put_executions([running_execution])

    # TODO(goutham): Consider concurrent task generation.
    for task in pipeline_detail.generator.generate():
      task_queue.enqueue(task)
예제 #4
0
def _maybe_enqueue_cancellation_task(mlmd_handle: metadata.Metadata,
                                     pipeline: pipeline_pb2.Pipeline,
                                     node: pipeline_pb2.PipelineNode,
                                     task_queue: tq.TaskQueue,
                                     pause: bool = False) -> bool:
    """Enqueues a node cancellation task if not already stopped.

    If the node has an ExecNodeTask in the task queue, issue a cancellation.
    Otherwise, when pause=False, if the node has an active execution in MLMD
    but no ExecNodeTask enqueued, it may be due to orchestrator restart after
    stopping was initiated but before the schedulers could finish. So, enqueue
    an ExecNodeTask with is_cancelled set to give a chance for the scheduler
    to finish gracefully.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: The pipeline containing the node to cancel.
      node: The node to cancel.
      task_queue: A `TaskQueue` instance into which any cancellation tasks
        will be enqueued.
      pause: Whether the cancellation is to pause the node rather than
        cancelling the execution.

    Returns:
      `True` if a cancellation task was enqueued. `False` if node is already
      stopped or no cancellation was required.
    """
    task_id = task_lib.exec_node_task_id_from_pipeline_node(pipeline, node)
    # An in-flight ExecNodeTask gets an explicit cancellation request.
    if task_queue.contains_task_id(task_id):
        node_uid = task_lib.NodeUid.from_pipeline_node(pipeline, node)
        task_queue.enqueue(
            task_lib.CancelNodeTask(node_uid=node_uid, pause=pause))
        return True
    # When pausing, an active execution without an enqueued task needs no
    # further action here.
    if pause:
        return False
    # Otherwise re-enqueue a cancelled ExecNodeTask for any active execution
    # (e.g. left over from an orchestrator restart) so the scheduler can
    # finish gracefully.
    executions = task_gen_utils.get_executions(mlmd_handle, node)
    cancel_task = task_gen_utils.generate_task_from_active_execution(
        mlmd_handle, pipeline, node, executions, is_cancelled=True)
    if cancel_task:
        task_queue.enqueue(cancel_task)
        return True
    return False
예제 #5
0
def _orchestrate_active_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates active pipeline.

    Promotes a NEW pipeline execution to RUNNING in MLMD, constructs the task
    generator matching the pipeline's execution mode (for ASYNC, first issuing
    stop/cancellation actions for stop-initiated nodes), then generates tasks
    and dispatches them: ExecNodeTasks are enqueued, FinalizeNodeTasks initiate
    node stop, and a FinalizePipelineTask initiates pipeline stop.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task_queue: A `TaskQueue` instance into which generated tasks are
        enqueued.
      service_job_manager: Used to stop/inspect service jobs for service
        nodes.
      pipeline_state: State of the pipeline to orchestrate.
    """
    pipeline = pipeline_state.pipeline
    execution = pipeline_state.execution
    # Only NEW or RUNNING pipelines are expected here.
    assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                          metadata_store_pb2.Execution.RUNNING)
    # Promote a freshly created pipeline execution to RUNNING in MLMD.
    if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
        updated_execution = copy.deepcopy(execution)
        updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
        mlmd_handle.store.put_executions([updated_execution])

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        # Create cancellation tasks for stop-initiated nodes if necessary.
        stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
        for node in stop_initiated_nodes:
            # Pure service nodes are stopped via the service job manager.
            # Otherwise, try enqueueing a cancellation task; if none was
            # needed and the node is a mixed service node, stop its services
            # as well.
            if service_job_manager.is_pure_service_node(
                    pipeline_state, node.node_info.id):
                service_job_manager.stop_node_services(pipeline_state,
                                                       node.node_info.id)
            elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                                  task_queue):
                pass
            elif service_job_manager.is_mixed_service_node(
                    pipeline_state, node.node_info.id):
                service_job_manager.stop_node_services(pipeline_state,
                                                       node.node_info.id)
        # Stop-initiated nodes are excluded from async task generation.
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager,
            set(n.node_info.id for n in stop_initiated_nodes))
    else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    tasks = generator.generate()

    # Dispatch generated tasks while holding the pipeline state context.
    with pipeline_state:
        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task = typing.cast(task_lib.ExecNodeTask, task)
                task_queue.enqueue(task)
            elif task_lib.is_finalize_node_task(task):
                # Node finalization is only generated in ASYNC mode.
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
                task = typing.cast(task_lib.FinalizeNodeTask, task)
                pipeline_state.initiate_node_stop(task.node_uid, task.status)
            else:
                # Pipeline finalization is only generated in SYNC mode and is
                # expected to be the sole generated task.
                assert task_lib.is_finalize_pipeline_task(task)
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                assert len(tasks) == 1
                task = typing.cast(task_lib.FinalizePipelineTask, task)
                if task.status.code == status_lib.Code.OK:
                    logging.info('Pipeline run successful; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                else:
                    logging.info('Pipeline run failed; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                pipeline_state.initiate_stop(task.status)
예제 #6
0
def _orchestrate_active_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates active pipeline.

    Promotes the pipeline execution to RUNNING, aborts SYNC pipelines that
    exceeded their deadline, stops nodes in STOPPING state (transitioning them
    to STOPPED once stopped), then runs the mode-appropriate task generator
    and applies the generated tasks: node state updates, ExecNodeTask
    enqueueing, and pipeline finalization.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task_queue: A `TaskQueue` instance into which generated tasks are
        enqueued.
      service_job_manager: Used to stop/inspect service jobs for service
        nodes.
      pipeline_state: State of the pipeline to orchestrate.
    """
    pipeline = pipeline_state.pipeline
    with pipeline_state:
        assert pipeline_state.is_active()
        # Promote the pipeline execution to RUNNING if not already.
        if pipeline_state.get_pipeline_execution_state() != (
                metadata_store_pb2.Execution.RUNNING):
            pipeline_state.set_pipeline_execution_state(
                metadata_store_pb2.Execution.RUNNING)
        orchestration_options = pipeline_state.get_orchestration_options()
        logging.info('Orchestration options: %s', orchestration_options)
        deadline_secs = orchestration_options.deadline_secs
        # A SYNC pipeline with a positive deadline that has been running
        # longer than the deadline is aborted with DEADLINE_EXCEEDED.
        if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                and deadline_secs > 0 and time.time() -
                pipeline_state.pipeline_creation_time_secs_since_epoch() >
                deadline_secs):
            logging.error(
                'Aborting pipeline due to exceeding deadline (%s secs); '
                'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
            pipeline_state.initiate_stop(
                status_lib.Status(
                    code=status_lib.Code.DEADLINE_EXCEEDED,
                    message=('Pipeline aborted due to exceeding deadline '
                             f'({deadline_secs} secs)')))
            return

    def _filter_by_state(node_infos: List[_NodeInfo],
                         state_str: str) -> List[_NodeInfo]:
        """Returns node infos whose state matches `state_str`."""
        return [n for n in node_infos if n.state.state == state_str]

    node_infos = _get_node_infos(pipeline_state)
    stopping_node_infos = _filter_by_state(node_infos,
                                           pstate.NodeState.STOPPING)

    # Tracks nodes stopped in the current iteration.
    stopped_node_infos: List[_NodeInfo] = []

    # Create cancellation tasks for nodes in state STOPPING.
    for node_info in stopping_node_infos:
        # Pure service nodes are stopped via the service job manager; a node
        # whose services are confirmed stopped counts as stopped. Otherwise,
        # try enqueueing a cancellation task; if none was needed, a mixed
        # service node gets its services stopped, and a plain node with no
        # active work is considered stopped already.
        if service_job_manager.is_pure_service_node(
                pipeline_state, node_info.node.node_info.id):
            if service_job_manager.stop_node_services(
                    pipeline_state, node_info.node.node_info.id):
                stopped_node_infos.append(node_info)
        elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                              node_info.node, task_queue):
            pass
        elif service_job_manager.is_mixed_service_node(
                pipeline_state, node_info.node.node_info.id):
            if service_job_manager.stop_node_services(
                    pipeline_state, node_info.node.node_info.id):
                stopped_node_infos.append(node_info)
        else:
            stopped_node_infos.append(node_info)

    # Change the state of stopped nodes from STOPPING to STOPPED.
    if stopped_node_infos:
        with pipeline_state:
            for node_info in stopped_node_infos:
                node_uid = task_lib.NodeUid.from_pipeline_node(
                    pipeline, node_info.node)
                with pipeline_state.node_state_update_context(
                        node_uid) as node_state:
                    node_state.update(pstate.NodeState.STOPPED,
                                      node_state.status)

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle,
            task_queue.contains_task_id,
            service_job_manager,
            fail_fast=orchestration_options.fail_fast)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, task_queue.contains_task_id, service_job_manager)
    else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    tasks = generator.generate(pipeline_state)

    with pipeline_state:
        # Handle all the UpdateNodeStateTasks by updating node states.
        for task in tasks:
            if task_lib.is_update_node_state_task(task):
                task = typing.cast(task_lib.UpdateNodeStateTask, task)
                with pipeline_state.node_state_update_context(
                        task.node_uid) as node_state:
                    node_state.update(task.state, task.status)

        # Only non-state-update tasks remain to dispatch below.
        tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

        # If there are still nodes in state STARTING, change them to STARTED.
        for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
            node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline_state.pipeline, node)
            with pipeline_state.node_state_update_context(
                    node_uid) as node_state:
                if node_state.state == pstate.NodeState.STARTING:
                    node_state.update(pstate.NodeState.STARTED)

        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task = typing.cast(task_lib.ExecNodeTask, task)
                task_queue.enqueue(task)
            else:
                # Pipeline finalization is only generated in SYNC mode and is
                # expected to be the sole remaining task.
                assert task_lib.is_finalize_pipeline_task(task)
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                assert len(tasks) == 1
                task = typing.cast(task_lib.FinalizePipelineTask, task)
                if task.status.code == status_lib.Code.OK:
                    logging.info('Pipeline run successful; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                else:
                    logging.info('Pipeline run failed; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                pipeline_state.initiate_stop(task.status)