def _cancel_nodes(
    mlmd_handle: metadata.Metadata,
    task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState,
    pause: bool,
    reload_node_ids: Optional[List[str]] = None) -> bool:
    """Cancels pipeline nodes and returns `True` if any node is currently active.

    Args:
      mlmd_handle: Handle forwarded to `_maybe_enqueue_cancellation_task`.
      task_queue: Task queue forwarded to `_maybe_enqueue_cancellation_task`.
      service_job_manager: Used to classify nodes as pure / mixed service nodes
        and to stop their services.
      pipeline_state: State of the pipeline whose nodes are being cancelled.
      pause: Forwarded as the `pause` keyword argument to
        `_maybe_enqueue_cancellation_task`.
      reload_node_ids: When not `None`, only nodes whose id appears in this list
        are processed; every other node is skipped.

    Returns:
      `True` if, for at least one processed node, `stop_node_services` returned
      a falsy value or `_maybe_enqueue_cancellation_task` returned a truthy
      value; `False` otherwise.
    """
    pipeline = pipeline_state.pipeline
    any_active = False
    for node in pstate.get_all_pipeline_nodes(pipeline):
        node_id = node.node_info.id
        # TODO(b/217584342): Partial reload which excludes service nodes is not
        # fully supported in async pipelines since we don't have a mechanism to
        # reload them later for new executions.
        if reload_node_ids is not None and node_id not in reload_node_ids:
            continue

        if service_job_manager.is_pure_service_node(pipeline_state, node_id):
            # A falsy result from stopping services counts as still active.
            still_active = not service_job_manager.stop_node_services(
                pipeline_state, node_id)
        elif _maybe_enqueue_cancellation_task(
                mlmd_handle, pipeline, node, task_queue, pause=pause):
            still_active = True
        elif service_job_manager.is_mixed_service_node(pipeline_state, node_id):
            still_active = not service_job_manager.stop_node_services(
                pipeline_state, node_id)
        else:
            still_active = False

        any_active = any_active or still_active
    return any_active
def _ensure_services(service_jobs_manager: service_jobs.ServiceJobManager,
                     pipeline_state: pstate.PipelineState) -> None:
    """Runs `ensure_services` and initiates a stop for each node uid it returns.

    Args:
      service_jobs_manager: Manager whose `ensure_services` is invoked with the
        pipeline state.
      pipeline_state: Pipeline state; it is entered as a context manager only
        when `ensure_services` returns a non-empty result.
    """
    failed_node_uids = service_jobs_manager.ensure_services(pipeline_state)
    if not failed_node_uids:
        # Nothing was reported, so the pipeline state is left untouched.
        return
    with pipeline_state:
        for failed_uid in failed_node_uids:
            pipeline_state.initiate_node_stop(failed_uid)
def _get_pure_service_node_ids(
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> Set[str]:
    """Returns the ids of all pipeline nodes that are pure service nodes.

    Args:
      service_job_manager: Manager whose `is_pure_service_node` decides
        membership in the result.
      pipeline_state: State of the pipeline whose nodes are inspected.

    Returns:
      The set of `node_info.id` values for which `is_pure_service_node`
      returned a truthy value.
    """
    all_node_ids = (
        node.node_info.id
        for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline))
    return {
        node_id for node_id in all_node_ids
        if service_job_manager.is_pure_service_node(pipeline_state, node_id)
    }
def _orchestrate_stop_initiated_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates stop initiated pipeline.

    Stops services of every pure service node and calls
    `_maybe_enqueue_cancellation_task` for every other node. If none of those
    calls returns a truthy value, the pipeline execution state is updated with
    the stop reason.

    Args:
      mlmd_handle: Handle forwarded to `_maybe_enqueue_cancellation_task`.
      task_queue: Task queue forwarded to `_maybe_enqueue_cancellation_task`.
      service_job_manager: Used to detect pure service nodes and stop their
        services.
      pipeline_state: State of the stop-initiated pipeline.
    """
    stop_reason = pipeline_state.stop_initiated_reason()
    assert stop_reason is not None
    pipeline = pipeline_state.pipeline

    active_found = False
    for node in pstate.get_all_pipeline_nodes(pipeline):
        node_id = node.node_info.id
        if service_job_manager.is_pure_service_node(pipeline_state, node_id):
            # The return value is intentionally not inspected here.
            service_job_manager.stop_node_services(pipeline_state, node_id)
            continue
        if _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                            task_queue):
            active_found = True

    if active_found:
        return
    with pipeline_state:
        # Update pipeline execution state in MLMD.
        pipeline_state.update_pipeline_execution_state(stop_reason)
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates active pipeline.

    Writes the pipeline execution back with state RUNNING if it is not already
    RUNNING, builds the task generator matching the pipeline's execution mode
    (for ASYNC, stop-initiated nodes are handled first), and then processes
    each generated task inside a `with pipeline_state:` block.

    Args:
      mlmd_handle: Metadata handle; used to store the updated execution and
        passed to the task generator and `_maybe_enqueue_cancellation_task`.
      task_queue: Queue onto which exec-node tasks are enqueued.
      service_job_manager: Used to classify and stop service nodes; also passed
        to the task generator.
      pipeline_state: State of the active pipeline.

    Raises:
      status_lib.StatusNotOkError: With code `FAILED_PRECONDITION` when the
        pipeline execution mode is neither SYNC nor ASYNC.
    """
    pipeline = pipeline_state.pipeline
    execution = pipeline_state.execution
    exec_states = metadata_store_pb2.Execution
    assert execution.last_known_state in (exec_states.NEW, exec_states.RUNNING)
    if execution.last_known_state != exec_states.RUNNING:
        # Store a modified copy rather than mutating the state's own execution.
        running_execution = copy.deepcopy(execution)
        running_execution.last_known_state = exec_states.RUNNING
        mlmd_handle.store.put_executions([running_execution])

    # Initialize task generator for the pipeline.
    mode = pipeline.execution_mode
    if mode not in (pipeline_pb2.Pipeline.SYNC, pipeline_pb2.Pipeline.ASYNC):
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    if mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager)
    else:
        # Create cancellation tasks for stop-initiated nodes if necessary.
        stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
        for node in stop_initiated_nodes:
            node_id = node.node_info.id
            is_pure = service_job_manager.is_pure_service_node(
                pipeline_state, node_id)
            # Pure service nodes never get a cancellation task; for the rest,
            # a truthy result means nothing more is done for this node.
            if not is_pure and _maybe_enqueue_cancellation_task(
                    mlmd_handle, pipeline, node, task_queue):
                continue
            if is_pure or service_job_manager.is_mixed_service_node(
                    pipeline_state, node_id):
                service_job_manager.stop_node_services(pipeline_state, node_id)
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager,
            {n.node_info.id for n in stop_initiated_nodes})

    tasks = generator.generate()

    with pipeline_state:
        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task_queue.enqueue(typing.cast(task_lib.ExecNodeTask, task))
                continue

            if task_lib.is_finalize_node_task(task):
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
                node_task = typing.cast(task_lib.FinalizeNodeTask, task)
                pipeline_state.initiate_node_stop(node_task.node_uid,
                                                  node_task.status)
                continue

            # Anything else must be the single finalize-pipeline task of a
            # SYNC pipeline.
            assert task_lib.is_finalize_pipeline_task(task)
            assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
            assert len(tasks) == 1
            final_task = typing.cast(task_lib.FinalizePipelineTask, task)
            if final_task.status.code == status_lib.Code.OK:
                log_format = 'Pipeline run successful; pipeline uid: %s'
            else:
                log_format = 'Pipeline run failed; pipeline uid: %s'
            logging.info(log_format, pipeline_state.pipeline_uid)
            pipeline_state.initiate_stop(final_task.status)
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates active pipeline.

    Steps, in order:
      1. Sets the pipeline execution state to RUNNING if it is not already, and
         initiates a pipeline stop (then returns) when a SYNC pipeline with a
         positive `deadline_secs` has exceeded that deadline.
      2. For nodes in state STOPPING, stops services and/or calls
         `_maybe_enqueue_cancellation_task`; nodes determined to be stopped are
         moved to STOPPED.
      3. Builds the SYNC or ASYNC task generator and generates tasks.
      4. Applies update-node-state tasks, moves remaining STARTING nodes to
         STARTED, enqueues exec-node tasks, and handles a finalize-pipeline
         task by initiating a pipeline stop.

    Args:
      mlmd_handle: Metadata handle passed to the task generator and to
        `_maybe_enqueue_cancellation_task`.
      task_queue: Queue onto which exec-node tasks are enqueued.
      service_job_manager: Used to classify and stop service nodes; also passed
        to the task generator.
      pipeline_state: State of the active pipeline.

    Raises:
      status_lib.StatusNotOkError: With code `FAILED_PRECONDITION` when the
        pipeline execution mode is neither SYNC nor ASYNC.
    """
    pipeline = pipeline_state.pipeline

    with pipeline_state:
        assert pipeline_state.is_active()
        running = metadata_store_pb2.Execution.RUNNING
        if pipeline_state.get_pipeline_execution_state() != running:
            pipeline_state.set_pipeline_execution_state(running)

        orchestration_options = pipeline_state.get_orchestration_options()
        logging.info('Orchestration options: %s', orchestration_options)
        deadline_secs = orchestration_options.deadline_secs
        # Short-circuit order matters: the clock is only read for SYNC
        # pipelines that have a positive deadline.
        deadline_exceeded = (
            pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
            and deadline_secs > 0
            and time.time() -
            pipeline_state.pipeline_creation_time_secs_since_epoch() >
            deadline_secs)
        if deadline_exceeded:
            logging.error(
                'Aborting pipeline due to exceeding deadline (%s secs); '
                'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
            pipeline_state.initiate_stop(
                status_lib.Status(
                    code=status_lib.Code.DEADLINE_EXCEEDED,
                    message=('Pipeline aborted due to exceeding deadline '
                             f'({deadline_secs} secs)')))
            return

    stopping_state = pstate.NodeState.STOPPING
    stopping_node_infos = [
        info for info in _get_node_infos(pipeline_state)
        if info.state.state == stopping_state
    ]

    # Tracks nodes stopped in the current iteration.
    stopped_node_infos: List[_NodeInfo] = []

    # Create cancellation tasks for nodes in state STOPPING.
    for node_info in stopping_node_infos:
        node_id = node_info.node.node_info.id
        if service_job_manager.is_pure_service_node(pipeline_state, node_id):
            stopped = service_job_manager.stop_node_services(
                pipeline_state, node_id)
        elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                              node_info.node, task_queue):
            stopped = False
        elif service_job_manager.is_mixed_service_node(pipeline_state, node_id):
            stopped = service_job_manager.stop_node_services(
                pipeline_state, node_id)
        else:
            # Neither services to stop nor a cancellation task was needed.
            stopped = True
        if stopped:
            stopped_node_infos.append(node_info)

    # Change the state of stopped nodes from STOPPING to STOPPED.
    if stopped_node_infos:
        with pipeline_state:
            for node_info in stopped_node_infos:
                stopped_uid = task_lib.NodeUid.from_pipeline_node(
                    pipeline, node_info.node)
                with pipeline_state.node_state_update_context(
                        stopped_uid) as node_state:
                    node_state.update(pstate.NodeState.STOPPED,
                                      node_state.status)

    # Initialize task generator for the pipeline.
    mode = pipeline.execution_mode
    if mode not in (pipeline_pb2.Pipeline.SYNC, pipeline_pb2.Pipeline.ASYNC):
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))
    if mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle,
            task_queue.contains_task_id,
            service_job_manager,
            fail_fast=orchestration_options.fail_fast)
    else:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, task_queue.contains_task_id, service_job_manager)

    tasks = generator.generate(pipeline_state)

    with pipeline_state:
        # Handle all the UpdateNodeStateTasks by updating node states.
        for task in tasks:
            if not task_lib.is_update_node_state_task(task):
                continue
            update_task = typing.cast(task_lib.UpdateNodeStateTask, task)
            with pipeline_state.node_state_update_context(
                    update_task.node_uid) as node_state:
                node_state.update(update_task.state, update_task.status)

        tasks = [
            remaining for remaining in tasks
            if not task_lib.is_update_node_state_task(remaining)
        ]

        # If there are still nodes in state STARTING, change them to STARTED.
        for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
            starting_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline_state.pipeline, node)
            with pipeline_state.node_state_update_context(
                    starting_uid) as node_state:
                if node_state.state == pstate.NodeState.STARTING:
                    node_state.update(pstate.NodeState.STARTED)

        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task_queue.enqueue(typing.cast(task_lib.ExecNodeTask, task))
                continue

            # Anything else must be the single finalize-pipeline task of a
            # SYNC pipeline.
            assert task_lib.is_finalize_pipeline_task(task)
            assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
            assert len(tasks) == 1
            final_task = typing.cast(task_lib.FinalizePipelineTask, task)
            if final_task.status.code == status_lib.Code.OK:
                log_format = 'Pipeline run successful; pipeline uid: %s'
            else:
                log_format = 'Pipeline run failed; pipeline uid: %s'
            logging.info(log_format, pipeline_state.pipeline_uid)
            pipeline_state.initiate_stop(final_task.status)