def _process_stop_initiated_pipelines(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    pipeline_details: Sequence[_PipelineDetail]) -> None:
  """Processes stop-initiated pipelines."""
  for detail in pipeline_details:
    pipeline = detail.pipeline_state.pipeline
    execution = detail.pipeline_state.execution
    has_active_executions = False
    for node in _get_all_pipeline_nodes(pipeline):
      # If the node has an ExecNodeTask in the task queue, issue a
      # cancellation. Otherwise, if the node has an active execution in MLMD
      # but no ExecNodeTask enqueued, it may be due to orchestrator restart
      # after pipeline stop was initiated but before the schedulers could
      # finish. So, enqueue an ExecNodeTask with is_cancelled set to give a
      # chance for the scheduler to finish gracefully.
      exec_node_task_id = task_lib.exec_node_task_id_from_pipeline_node(
          pipeline, node)
      if task_queue.contains_task_id(exec_node_task_id):
        task_queue.enqueue(
            task_lib.CancelNodeTask(
                node_uid=task_lib.NodeUid.from_pipeline_node(pipeline, node)))
        has_active_executions = True
      else:
        executions = task_gen_utils.get_executions(mlmd_handle, node)
        exec_node_task = task_gen_utils.generate_task_from_active_execution(
            mlmd_handle, pipeline, node, executions, is_cancelled=True)
        if exec_node_task:
          task_queue.enqueue(exec_node_task)
          has_active_executions = True
    if not has_active_executions:
      updated_execution = copy.deepcopy(execution)
      updated_execution.last_known_state = metadata_store_pb2.Execution.CANCELED
      mlmd_handle.store.put_executions([updated_execution])
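
# `_PipelineDetail` is consumed above (and in `_process_active_pipelines`
# below) but is not defined in this excerpt. The following is a minimal
# sketch of such a container, assuming it simply pairs a pipeline's state
# with the task generator created for it; the field set and the
# `task_gen.TaskGenerator` annotation are assumptions, not the actual
# definition (`typing` is already imported by this module).
class _PipelineDetail(typing.NamedTuple):
  pipeline_state: pstate.PipelineState
  generator: 'task_gen.TaskGenerator'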
def _process_active_pipelines(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: Optional[service_jobs.ServiceJobManager],
    pipeline_states: Sequence[pstate.PipelineState]) -> None:
  """Processes active pipelines."""
  for pipeline_state in pipeline_states:
    pipeline = pipeline_state.pipeline
    execution = pipeline_state.execution
    assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                          metadata_store_pb2.Execution.RUNNING)
    if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
      updated_execution = copy.deepcopy(execution)
      updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
      mlmd_handle.store.put_executions([updated_execution])

    if service_job_manager is not None:
      # Ensure all the required services are running.
      _ensure_services(service_job_manager, pipeline_state)
      pure_service_node_ids = _get_pure_service_node_ids(
          service_job_manager, pipeline_state)
    else:
      pure_service_node_ids = set()

    # Create cancellation tasks for stop-initiated nodes if necessary.
    stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
    for node in stop_initiated_nodes:
      if node.node_info.id not in pure_service_node_ids:
        _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                         task_queue)

    ignore_node_ids = set(
        n.node_info.id for n in stop_initiated_nodes) | pure_service_node_ids

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
      generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
          mlmd_handle, pipeline, task_queue.contains_task_id, ignore_node_ids)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
      generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
          mlmd_handle, pipeline, task_queue.contains_task_id, ignore_node_ids)
    else:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.FAILED_PRECONDITION,
          message=(
              f'Only SYNC and ASYNC pipeline execution modes supported; '
              f'found pipeline with execution mode: {pipeline.execution_mode}'))

    # TODO(goutham): Consider concurrent task generation.
    tasks = generator.generate()
    for task in tasks:
      task_queue.enqueue(task)
def _process_active_pipelines(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    pipeline_details: Sequence[_PipelineDetail]) -> None:
  """Processes active pipelines."""
  for detail in pipeline_details:
    execution = detail.pipeline_state.execution
    assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                          metadata_store_pb2.Execution.RUNNING)
    if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
      updated_execution = copy.deepcopy(execution)
      updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
      mlmd_handle.store.put_executions([updated_execution])

    # TODO(goutham): Consider concurrent task generation.
    tasks = detail.generator.generate()
    for task in tasks:
      task_queue.enqueue(task)
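
# A hypothetical dispatch sketch showing how the two processing functions
# above might be driven from an outer loop. `_is_stop_initiated` is an
# assumed predicate standing in for however the orchestrator actually
# partitions pipelines; it is not part of this excerpt.
def _example_process_pipelines(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    details: Sequence[_PipelineDetail]) -> None:
  stop_initiated = [d for d in details if _is_stop_initiated(d.pipeline_state)]
  active = [d for d in details if not _is_stop_initiated(d.pipeline_state)]
  _process_stop_initiated_pipelines(mlmd_handle, task_queue, stop_initiated)
  _process_active_pipelines(mlmd_handle, task_queue, active)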
def _maybe_enqueue_cancellation_task(mlmd_handle: metadata.Metadata,
                                     pipeline: pipeline_pb2.Pipeline,
                                     node: pipeline_pb2.PipelineNode,
                                     task_queue: tq.TaskQueue,
                                     pause: bool = False) -> bool:
  """Enqueues a node cancellation task if not already stopped.

  If the node has an ExecNodeTask in the task queue, issue a cancellation.
  Otherwise, when pause=False, if the node has an active execution in MLMD
  but no ExecNodeTask enqueued, it may be due to orchestrator restart after
  stopping was initiated but before the schedulers could finish. So, enqueue
  an ExecNodeTask with is_cancelled set to give a chance for the scheduler
  to finish gracefully.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: The pipeline containing the node to cancel.
    node: The node to cancel.
    task_queue: A `TaskQueue` instance into which any cancellation tasks will
      be enqueued.
    pause: Whether the cancellation is to pause the node rather than cancelling
      the execution.

  Returns:
    `True` if a cancellation task was enqueued. `False` if node is already
    stopped or no cancellation was required.
  """
  exec_node_task_id = task_lib.exec_node_task_id_from_pipeline_node(
      pipeline, node)
  if task_queue.contains_task_id(exec_node_task_id):
    task_queue.enqueue(
        task_lib.CancelNodeTask(
            node_uid=task_lib.NodeUid.from_pipeline_node(pipeline, node),
            pause=pause))
    return True
  if not pause:
    executions = task_gen_utils.get_executions(mlmd_handle, node)
    exec_node_task = task_gen_utils.generate_task_from_active_execution(
        mlmd_handle, pipeline, node, executions, is_cancelled=True)
    if exec_node_task:
      task_queue.enqueue(exec_node_task)
      return True
  return False
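
# A minimal usage sketch for the `pause` path: request a pause for a node and
# log the outcome based on the boolean result. `_example_pause_node` and its
# logging are illustrative only, not part of the orchestrator.
def _example_pause_node(mlmd_handle: metadata.Metadata,
                        pipeline: pipeline_pb2.Pipeline,
                        node: pipeline_pb2.PipelineNode,
                        task_queue: tq.TaskQueue) -> None:
  if _maybe_enqueue_cancellation_task(
      mlmd_handle, pipeline, node, task_queue, pause=True):
    logging.info('Pause requested for node: %s', node.node_info.id)
  else:
    logging.info('Node is not currently active; nothing to pause: %s',
                 node.node_info.id)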
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates active pipeline."""
  pipeline = pipeline_state.pipeline
  execution = pipeline_state.execution
  assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                        metadata_store_pb2.Execution.RUNNING)
  if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
    updated_execution = copy.deepcopy(execution)
    updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
    mlmd_handle.store.put_executions([updated_execution])

  # Initialize task generator for the pipeline.
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle, pipeline_state, task_queue.contains_task_id,
        service_job_manager)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    # Create cancellation tasks for stop-initiated nodes if necessary.
    stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
    for node in stop_initiated_nodes:
      if service_job_manager.is_pure_service_node(pipeline_state,
                                                  node.node_info.id):
        service_job_manager.stop_node_services(pipeline_state,
                                               node.node_info.id)
      elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                            task_queue):
        pass
      elif service_job_manager.is_mixed_service_node(pipeline_state,
                                                     node.node_info.id):
        service_job_manager.stop_node_services(pipeline_state,
                                               node.node_info.id)
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, pipeline_state, task_queue.contains_task_id,
        service_job_manager,
        set(n.node_info.id for n in stop_initiated_nodes))
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'))

  tasks = generator.generate()

  with pipeline_state:
    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task = typing.cast(task_lib.ExecNodeTask, task)
        task_queue.enqueue(task)
      elif task_lib.is_finalize_node_task(task):
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
        task = typing.cast(task_lib.FinalizeNodeTask, task)
        pipeline_state.initiate_node_stop(task.node_uid, task.status)
      else:
        assert task_lib.is_finalize_pipeline_task(task)
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
        assert len(tasks) == 1
        task = typing.cast(task_lib.FinalizePipelineTask, task)
        if task.status.code == status_lib.Code.OK:
          logging.info('Pipeline run successful; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        else:
          logging.info('Pipeline run failed; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        pipeline_state.initiate_stop(task.status)
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates active pipeline."""
  pipeline = pipeline_state.pipeline
  with pipeline_state:
    assert pipeline_state.is_active()
    if pipeline_state.get_pipeline_execution_state() != (
        metadata_store_pb2.Execution.RUNNING):
      pipeline_state.set_pipeline_execution_state(
          metadata_store_pb2.Execution.RUNNING)
    orchestration_options = pipeline_state.get_orchestration_options()
    logging.info('Orchestration options: %s', orchestration_options)
    deadline_secs = orchestration_options.deadline_secs
    if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC and
        deadline_secs > 0 and
        time.time() - pipeline_state.pipeline_creation_time_secs_since_epoch()
        > deadline_secs):
      logging.error(
          'Aborting pipeline due to exceeding deadline (%s secs); '
          'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
      pipeline_state.initiate_stop(
          status_lib.Status(
              code=status_lib.Code.DEADLINE_EXCEEDED,
              message=('Pipeline aborted due to exceeding deadline '
                       f'({deadline_secs} secs)')))
      return

  def _filter_by_state(node_infos: List[_NodeInfo],
                       state_str: str) -> List[_NodeInfo]:
    return [n for n in node_infos if n.state.state == state_str]

  node_infos = _get_node_infos(pipeline_state)
  stopping_node_infos = _filter_by_state(node_infos,
                                         pstate.NodeState.STOPPING)

  # Tracks nodes stopped in the current iteration.
  stopped_node_infos: List[_NodeInfo] = []

  # Create cancellation tasks for nodes in state STOPPING.
  for node_info in stopping_node_infos:
    if service_job_manager.is_pure_service_node(pipeline_state,
                                                node_info.node.node_info.id):
      if service_job_manager.stop_node_services(pipeline_state,
                                                node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                          node_info.node, task_queue):
      pass
    elif service_job_manager.is_mixed_service_node(
        pipeline_state, node_info.node.node_info.id):
      if service_job_manager.stop_node_services(pipeline_state,
                                                node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    else:
      stopped_node_infos.append(node_info)

  # Change the state of stopped nodes from STOPPING to STOPPED.
  if stopped_node_infos:
    with pipeline_state:
      for node_info in stopped_node_infos:
        node_uid = task_lib.NodeUid.from_pipeline_node(pipeline,
                                                       node_info.node)
        with pipeline_state.node_state_update_context(node_uid) as node_state:
          node_state.update(pstate.NodeState.STOPPED, node_state.status)

  # Initialize task generator for the pipeline.
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle, task_queue.contains_task_id, service_job_manager,
        fail_fast=orchestration_options.fail_fast)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, task_queue.contains_task_id, service_job_manager)
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'))

  tasks = generator.generate(pipeline_state)

  with pipeline_state:
    # Handle all the UpdateNodeStateTasks by updating node states.
    for task in tasks:
      if task_lib.is_update_node_state_task(task):
        task = typing.cast(task_lib.UpdateNodeStateTask, task)
        with pipeline_state.node_state_update_context(
            task.node_uid) as node_state:
          node_state.update(task.state, task.status)

    tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

    # If there are still nodes in state STARTING, change them to STARTED.
    for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
      node_uid = task_lib.NodeUid.from_pipeline_node(pipeline_state.pipeline,
                                                     node)
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        if node_state.state == pstate.NodeState.STARTING:
          node_state.update(pstate.NodeState.STARTED)

    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task = typing.cast(task_lib.ExecNodeTask, task)
        task_queue.enqueue(task)
      else:
        assert task_lib.is_finalize_pipeline_task(task)
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
        assert len(tasks) == 1
        task = typing.cast(task_lib.FinalizePipelineTask, task)
        if task.status.code == status_lib.Code.OK:
          logging.info('Pipeline run successful; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        else:
          logging.info('Pipeline run failed; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        pipeline_state.initiate_stop(task.status)
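
# A sketch of how a single orchestration iteration might drive the function
# above, assuming some helper (here the hypothetical
# `_get_active_pipeline_states`) enumerates the active pipelines. This is not
# the module's actual entry point; it only illustrates the calling pattern.
def _example_orchestration_iteration(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager) -> None:
  for pipeline_state in _get_active_pipeline_states(mlmd_handle):
    try:
      _orchestrate_active_pipeline(mlmd_handle, task_queue,
                                   service_job_manager, pipeline_state)
    except Exception:  # pylint: disable=broad-except
      logging.exception('Exception while orchestrating pipeline; uid: %s',
                        pipeline_state.pipeline_uid)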