def _orchestrate_stop_initiated_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Advances a pipeline whose stop has been initiated.

  Every node of the pipeline is visited once:
    * for a pure service node, its services are asked to stop;
    * otherwise `_maybe_enqueue_cancellation_task` is called, and a truthy
      result marks the pipeline as still having active executions;
    * otherwise, for a mixed service node, its services are asked to stop.

  If no node was marked active, the pipeline execution state is updated with
  the stop reason.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task_queue: Task queue passed to `_maybe_enqueue_cancellation_task`.
    service_job_manager: Manager used to classify nodes and stop services.
    pipeline_state: State of the pipeline being stopped; its stop initiated
      reason must not be `None`.
  """
  reason = pipeline_state.stop_initiated_reason()
  assert reason is not None
  pipeline = pipeline_state.pipeline

  still_active = False
  for node in pstate.get_all_pipeline_nodes(pipeline):
    node_id = node.node_info.id
    if service_job_manager.is_pure_service_node(pipeline_state, node_id):
      service_job_manager.stop_node_services(pipeline_state, node_id)
      continue
    if _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                        task_queue):
      still_active = True
      continue
    if service_job_manager.is_mixed_service_node(pipeline_state, node_id):
      service_job_manager.stop_node_services(pipeline_state, node_id)

  if still_active:
    return

  with pipeline_state:
    # Update pipeline execution state in MLMD.
    pipeline_state.update_pipeline_execution_state(reason)
def _process_stop_initiated_pipelines(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: Optional[service_jobs.ServiceJobManager],
    pipeline_states: Sequence[pstate.PipelineState]) -> None:
  """Processes stop initiated pipelines.

  For each pipeline, `_maybe_enqueue_cancellation_task` is called for every
  node that is not a pure service node. If none of those calls returns a
  truthy value, the pipeline's services are stopped (when a service job
  manager is given) and a copy of the pipeline execution with
  `last_known_state` set to `CANCELED` is written to MLMD.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task_queue: Task queue passed to `_maybe_enqueue_cancellation_task`.
    service_job_manager: Optional manager of service jobs. When falsy, no
      node is treated as a pure service node.
    pipeline_states: States of the pipelines to process.
  """
  for pipeline_state in pipeline_states:
    pipeline = pipeline_state.pipeline
    if service_job_manager:
      skipped_node_ids = _get_pure_service_node_ids(service_job_manager,
                                                    pipeline_state)
    else:
      skipped_node_ids = set()
    execution = pipeline_state.execution

    # Built eagerly as a list so that every eligible node is visited; `any`
    # over a generator would stop at the first truthy result.
    cancellation_results = [
        _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                         task_queue)
        for node in pstate.get_all_pipeline_nodes(pipeline)
        if node.node_info.id not in skipped_node_ids
    ]
    if any(cancellation_results):
      continue

    if service_job_manager is not None:
      # Stop all the services associated with the pipeline.
      service_job_manager.stop_services(pipeline_state)
    # Update pipeline execution state in MLMD.
    canceled_execution = copy.deepcopy(execution)
    canceled_execution.last_known_state = metadata_store_pb2.Execution.CANCELED
    mlmd_handle.store.put_executions([canceled_execution])
def _cancel_nodes(mlmd_handle: metadata.Metadata,
                  task_queue: tq.TaskQueue,
                  service_job_manager: service_jobs.ServiceJobManager,
                  pipeline_state: pstate.PipelineState,
                  pause: bool,
                  reload_node_ids: Optional[List[str]] = None) -> bool:
  """Cancels pipeline nodes and reports whether any node is still active.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task_queue: Task queue passed to `_maybe_enqueue_cancellation_task`.
    service_job_manager: Manager used to classify nodes and stop services.
    pipeline_state: State of the pipeline whose nodes are cancelled.
    pause: Forwarded as `pause` to `_maybe_enqueue_cancellation_task`.
    reload_node_ids: If not `None`, only nodes whose id is in this collection
      are considered; all other nodes are skipped.

  Returns:
    `True` if, for at least one considered node, either stopping its services
    returned a falsy value or `_maybe_enqueue_cancellation_task` returned a
    truthy value; `False` otherwise.
  """
  pipeline = pipeline_state.pipeline
  any_active = False
  for node in pstate.get_all_pipeline_nodes(pipeline):
    node_id = node.node_info.id
    # TODO(b/217584342): Partial reload which excludes service nodes is not
    # fully supported in async pipelines since we don't have a mechanism to
    # reload them later for new executions.
    if reload_node_ids is not None and node_id not in reload_node_ids:
      continue

    if service_job_manager.is_pure_service_node(pipeline_state, node_id):
      services_stopped = service_job_manager.stop_node_services(
          pipeline_state, node_id)
      if not services_stopped:
        any_active = True
      continue

    if _maybe_enqueue_cancellation_task(
        mlmd_handle, pipeline, node, task_queue, pause=pause):
      any_active = True
      continue

    if service_job_manager.is_mixed_service_node(pipeline_state, node_id):
      services_stopped = service_job_manager.stop_node_services(
          pipeline_state, node_id)
      if not services_stopped:
        any_active = True
  return any_active
def _get_pure_service_node_ids(
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> Set[str]:
  """Returns the ids of all nodes classified as pure service nodes.

  Args:
    service_job_manager: Manager whose `is_pure_service_node` decides
      membership.
    pipeline_state: State of the pipeline whose nodes are inspected.

  Returns:
    The set of node ids for which `is_pure_service_node` is truthy.
  """
  return {
      node.node_info.id
      for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
      if service_job_manager.is_pure_service_node(pipeline_state,
                                                  node.node_info.id)
  }
def _get_stop_initiated_nodes(
    pipeline_state: pstate.PipelineState) -> List[pipeline_pb2.PipelineNode]:
  """Returns the pipeline nodes that have a stop initiated reason recorded.

  Args:
    pipeline_state: State of the pipeline whose nodes are inspected.

  Returns:
    The nodes, in pipeline order, for which
    `pipeline_state.node_stop_initiated_reason` is not `None`.
  """
  return [
      node for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
      if pipeline_state.node_stop_initiated_reason(
          task_lib.NodeUid.from_pipeline_node(pipeline_state.pipeline, node))
      is not None
  ]
def stop_node(
    mlmd_handle: metadata.Metadata,
    node_uid: task_lib.NodeUid,
    timeout_secs: float = DEFAULT_WAIT_FOR_INACTIVATION_TIMEOUT_SECS
) -> None:
  """Stops a node in an async pipeline.

  Records a node stop request in the pipeline state and, if the node has an
  active execution, waits for that execution to become inactive.

  Args:
    mlmd_handle: A handle to the MLMD db.
    node_uid: Uid of the node to be stopped.
    timeout_secs: Amount of time in seconds to wait for node to stop.

  Raises:
    status_lib.StatusNotOkError: Failure to stop the node. With code
      `INTERNAL` if the node cannot be uniquely identified in the pipeline
      or if the node has more than one active execution.
  """
  with _PIPELINE_OPS_LOCK:
    with pstate.PipelineState.load(
        mlmd_handle, node_uid.pipeline_uid) as pipeline_state:
      matches = [
          candidate for candidate in pstate.get_all_pipeline_nodes(
              pipeline_state.pipeline)
          if candidate.node_info.id == node_uid.node_id
      ]
      if len(matches) != 1:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.INTERNAL,
            message=(
                f'`stop_node` operation failed, unable to find node to stop: '
                f'{node_uid}'))
      (node,) = matches
      cancel_status = status_lib.Status(
          code=status_lib.Code.CANCELLED,
          message='Cancellation requested by client.')
      pipeline_state.initiate_node_stop(node_uid, cancel_status)

    # Executions are looked up after the pipeline state context has exited
    # but while the ops lock is still held.
    running = [
        execution
        for execution in task_gen_utils.get_executions(mlmd_handle, node)
        if execution_lib.is_execution_active(execution)
    ]
    if not running:
      # If there are no active executions, we're done.
      return
    if len(running) > 1:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.INTERNAL,
          message=(
              f'Unexpected multiple active executions for node: {node_uid}'))

  # The wait happens after the ops lock has been released.
  _wait_for_inactivation(mlmd_handle, running[0], timeout_secs=timeout_secs)
def _get_node_infos(pipeline_state: pstate.PipelineState) -> List[_NodeInfo]:
  """Builds one `_NodeInfo` per pipeline node, pairing it with its state.

  Node states are read while the `pipeline_state` context is entered.

  Args:
    pipeline_state: State of the pipeline whose nodes are described.

  Returns:
    A list of `_NodeInfo` objects, one per node, in pipeline order.
  """
  pipeline_nodes = pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
  with pipeline_state:
    node_infos: List[_NodeInfo] = [
        _NodeInfo(
            node=pipeline_node,
            state=pipeline_state.get_node_state(
                task_lib.NodeUid.from_pipeline_node(pipeline_state.pipeline,
                                                    pipeline_node)))
        for pipeline_node in pipeline_nodes
    ]
  return node_infos
def resume_manual_node(mlmd_handle: metadata.Metadata,
                       node_uid: task_lib.NodeUid) -> None:
  """Resumes a manual node.

  Looks up the single active execution of the manual node and writes a
  `COMPLETED` manual node state into that execution's custom property keyed
  by `manual_task_scheduler.NODE_STATE_PROPERTY_KEY`.

  Args:
    mlmd_handle: A handle to the MLMD db.
    node_uid: Uid of the manual node to be resumed.

  Raises:
    status_lib.StatusNotOkError: Failure to resume a manual node. With code
      `NOT_FOUND` if the node or an active execution of it cannot be found,
      `INVALID_ARGUMENT` if the node is not a manual node, and `INTERNAL` if
      the node has more than one active execution.
  """
  logging.info('Received request to resume manual node; node uid: %s',
               node_uid)
  with pstate.PipelineState.load(
      mlmd_handle, node_uid.pipeline_uid) as pipeline_state:
    matches = [
        candidate for candidate in pstate.get_all_pipeline_nodes(
            pipeline_state.pipeline)
        if candidate.node_info.id == node_uid.node_id
    ]
    if len(matches) != 1:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.NOT_FOUND,
          message=(f'Unable to find manual node to resume: {node_uid}'))

    manual_node = matches[0]
    if manual_node.node_info.type.name != constants.MANUAL_NODE_TYPE:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.INVALID_ARGUMENT,
          message=('Unable to resume a non-manual node. '
                   f'Got non-manual node id: {node_uid}'))

    live_executions = [
        candidate
        for candidate in task_gen_utils.get_executions(mlmd_handle,
                                                       manual_node)
        if execution_lib.is_execution_active(candidate)
    ]
    if not live_executions:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.NOT_FOUND,
          message=(
              f'Unable to find active manual node to resume: {node_uid}'))
    if len(live_executions) > 1:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.INTERNAL,
          message=(f'Unexpected multiple active executions for manual node: '
                   f'{node_uid}'))

    (live_execution,) = live_executions
    with mlmd_state.mlmd_execution_atomic_op(
        mlmd_handle=mlmd_handle,
        execution_id=live_execution.id) as execution:
      # Build the state first, then fetch the property it is written to.
      done_state = manual_task_scheduler.ManualNodeState(
          state=manual_task_scheduler.ManualNodeState.COMPLETED)
      state_property = execution.custom_properties.get_or_create(
          manual_task_scheduler.NODE_STATE_PROPERTY_KEY)
      done_state.set_mlmd_value(state_property)
def _process_stop_initiated_pipelines(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    pipeline_states: Sequence[pstate.PipelineState]) -> None:
  """Processes stop initiated pipelines.

  For each pipeline, `_maybe_enqueue_cancellation_task` is called for every
  node. If none of those calls returns a truthy value, a copy of the pipeline
  execution with `last_known_state` set to `CANCELED` is written to MLMD.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task_queue: Task queue passed to `_maybe_enqueue_cancellation_task`.
    pipeline_states: States of the pipelines to process.
  """
  for pipeline_state in pipeline_states:
    pipeline = pipeline_state.pipeline
    execution = pipeline_state.execution

    # Built eagerly as a list so that every node is visited; `any` over a
    # generator would stop at the first truthy result.
    cancellation_results = [
        _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                         task_queue)
        for node in pstate.get_all_pipeline_nodes(pipeline)
    ]
    if any(cancellation_results):
      continue

    canceled_execution = copy.deepcopy(execution)
    canceled_execution.last_known_state = metadata_store_pb2.Execution.CANCELED
    mlmd_handle.store.put_executions([canceled_execution])
def stop_node(mlmd_handle: metadata.Metadata,
              node_uid: task_lib.NodeUid,
              timeout_secs: Optional[float] = None) -> None:
  """Stops a node.

  Initiates a node stop operation and waits for the node execution to become
  inactive.

  Args:
    mlmd_handle: A handle to the MLMD db.
    node_uid: Uid of the node to be stopped.
    timeout_secs: Amount of time in seconds to wait for node to stop. If
      `None`, waits indefinitely.

  Raises:
    status_lib.StatusNotOkError: Failure to stop the node. With code
      `INTERNAL` if the node cannot be uniquely identified in the pipeline.
  """
  logging.info('Received request to stop node; node uid: %s', node_uid)
  with _PIPELINE_OPS_LOCK:
    with pstate.PipelineState.load(
        mlmd_handle, node_uid.pipeline_uid) as pipeline_state:
      # Only the number of matching nodes matters here; the node itself is
      # not needed because the state is addressed through `node_uid`.
      num_matches = sum(
          1 for n in pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
          if n.node_info.id == node_uid.node_id)
      if num_matches != 1:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.INTERNAL,
            message=(
                f'`stop_node` operation failed, unable to find node to stop: '
                f'{node_uid}'))
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        # A node that is not stoppable is left untouched.
        if node_state.is_stoppable():
          node_state.update(
              pstate.NodeState.STOPPING,
              status_lib.Status(
                  code=status_lib.Code.CANCELLED,
                  message='Cancellation requested by client.'))

  # Wait until the node is stopped or time out. The wait happens after the
  # ops lock has been released.
  _wait_for_node_inactivation(
      pipeline_state, node_uid, timeout_secs=timeout_secs)
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates active pipeline.

  The function proceeds in the following steps:
    1. Ensures the pipeline execution state is `RUNNING` and, for SYNC
       pipelines with a positive deadline, initiates a stop with
       `DEADLINE_EXCEEDED` and returns if the deadline has passed.
    2. Handles nodes in state `STOPPING`: stops services and/or enqueues
       cancellation tasks, then marks nodes `STOPPED` where applicable.
    3. Generates tasks with the SYNC or ASYNC task generator.
    4. Applies node state updates carried by the generated tasks, moves any
       remaining `STARTING` nodes to `STARTED`, enqueues exec-node tasks and
       handles a finalize-pipeline task by initiating a pipeline stop.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task_queue: Queue to which exec-node and cancellation tasks are added.
    service_job_manager: Manager used to classify nodes and stop services;
      also passed to the task generator.
    pipeline_state: State of the pipeline to orchestrate; must be active.

  Raises:
    status_lib.StatusNotOkError: With code `FAILED_PRECONDITION` if the
      pipeline execution mode is neither SYNC nor ASYNC.
  """
  pipeline = pipeline_state.pipeline
  with pipeline_state:
    assert pipeline_state.is_active()
    # Only write the execution state when it actually changes.
    if pipeline_state.get_pipeline_execution_state() != (
        metadata_store_pb2.Execution.RUNNING):
      pipeline_state.set_pipeline_execution_state(
          metadata_store_pb2.Execution.RUNNING)
    orchestration_options = pipeline_state.get_orchestration_options()
    logging.info('Orchestration options: %s', orchestration_options)
    deadline_secs = orchestration_options.deadline_secs
    # The deadline applies to SYNC pipelines only; a non-positive value
    # disables the check. Elapsed time is measured from pipeline creation.
    if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC and
        deadline_secs > 0 and
        time.time() - pipeline_state.pipeline_creation_time_secs_since_epoch() >
        deadline_secs):
      logging.error(
          'Aborting pipeline due to exceeding deadline (%s secs); '
          'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
      pipeline_state.initiate_stop(
          status_lib.Status(
              code=status_lib.Code.DEADLINE_EXCEEDED,
              message=('Pipeline aborted due to exceeding deadline '
                       f'({deadline_secs} secs)')))
      return

  def _filter_by_state(node_infos: List[_NodeInfo],
                       state_str: str) -> List[_NodeInfo]:
    # Keeps only the node infos whose node state equals `state_str`.
    return [n for n in node_infos if n.state.state == state_str]

  node_infos = _get_node_infos(pipeline_state)
  stopping_node_infos = _filter_by_state(node_infos, pstate.NodeState.STOPPING)

  # Tracks nodes stopped in the current iteration.
  stopped_node_infos: List[_NodeInfo] = []

  # Create cancellation tasks for nodes in state STOPPING.
  for node_info in stopping_node_infos:
    if service_job_manager.is_pure_service_node(
        pipeline_state, node_info.node.node_info.id):
      # NOTE(review): `stop_node_services` presumably returns a truthy value
      # once the services are stopped; only then is the node recorded as
      # stopped -- confirm against the service job manager contract.
      if service_job_manager.stop_node_services(
          pipeline_state, node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                          node_info.node, task_queue):
      # A truthy result leaves the node in state STOPPING for this iteration.
      pass
    elif service_job_manager.is_mixed_service_node(
        pipeline_state, node_info.node.node_info.id):
      if service_job_manager.stop_node_services(
          pipeline_state, node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    else:
      # Neither a service node nor a node with a cancellation task.
      stopped_node_infos.append(node_info)

  # Change the state of stopped nodes from STOPPING to STOPPED.
  if stopped_node_infos:
    with pipeline_state:
      for node_info in stopped_node_infos:
        node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, node_info.node)
        with pipeline_state.node_state_update_context(
            node_uid) as node_state:
          # The existing status is carried over into the STOPPED state.
          node_state.update(pstate.NodeState.STOPPED, node_state.status)

  # Initialize task generator for the pipeline.
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle,
        task_queue.contains_task_id,
        service_job_manager,
        fail_fast=orchestration_options.fail_fast)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, task_queue.contains_task_id, service_job_manager)
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'
        ))

  tasks = generator.generate(pipeline_state)

  with pipeline_state:
    # Handle all the UpdateNodeStateTasks by updating node states.
    for task in tasks:
      if task_lib.is_update_node_state_task(task):
        task = typing.cast(task_lib.UpdateNodeStateTask, task)
        with pipeline_state.node_state_update_context(
            task.node_uid) as node_state:
          node_state.update(task.state, task.status)

    # The state-update tasks have been applied above; drop them so that only
    # exec-node / finalize-pipeline tasks remain.
    tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

    # If there are still nodes in state STARTING, change them to STARTED.
    for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
      node_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline_state.pipeline, node)
      with pipeline_state.node_state_update_context(
          node_uid) as node_state:
        if node_state.state == pstate.NodeState.STARTING:
          node_state.update(pstate.NodeState.STARTED)

    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task = typing.cast(task_lib.ExecNodeTask, task)
        task_queue.enqueue(task)
      else:
        # Any other remaining task must be the single finalize-pipeline
        # task, which is asserted to occur for SYNC pipelines only.
        assert task_lib.is_finalize_pipeline_task(task)
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
        assert len(tasks) == 1
        task = typing.cast(task_lib.FinalizePipelineTask, task)
        if task.status.code == status_lib.Code.OK:
          logging.info('Pipeline run successful; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        else:
          logging.info('Pipeline run failed; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        # The stop is initiated with the task's status in both cases.
        pipeline_state.initiate_stop(task.status)
def resume_pipeline(mlmd_handle: metadata.Metadata,
                    pipeline: pipeline_pb2.Pipeline) -> pstate.PipelineState:
  """Resumes a pipeline run from previously failed nodes.

  The most recently created previous run of the pipeline is located, and a
  new partial run is started that covers all nodes of `pipeline` but skips
  the nodes that succeeded in that previous run.

  Upon success, MLMD is updated to signal that the pipeline must be started.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline to resume.

  Returns:
    The `PipelineState` object upon success.

  Raises:
    status_lib.StatusNotOkError: Failure to resume pipeline. With code
      `FAILED_PRECONDITION` if `pipeline` or its latest previous run is not a
      SYNC pipeline. With code `ALREADY_EXISTS` if a pipeline is already
      running. With code `NOT_FOUND` if a previous pipeline run is not found
      for resuming.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  logging.info('Received request to resume pipeline; pipeline uid: %s',
               pipeline_uid)
  if pipeline.execution_mode != pipeline_pb2.Pipeline.SYNC:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'
        ))

  # Find the most recently created run, refusing to resume if any run of
  # this pipeline is still active.
  latest_pipeline_view = None
  for view in pstate.PipelineView.load_all(mlmd_handle, pipeline_uid):
    execution = view.execution
    if execution_lib.is_execution_active(execution):
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.ALREADY_EXISTS,
          message=(
              f'Can not resume pipeline. An active pipeline is already '
              f'running with uid {pipeline_uid}.'))
    if (not latest_pipeline_view or execution.create_time_since_epoch >
        latest_pipeline_view.execution.create_time_since_epoch):
      latest_pipeline_view = view

  if not latest_pipeline_view:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.NOT_FOUND,
        message='Pipeline failed to resume. No previous pipeline run found.')
  if latest_pipeline_view.pipeline.execution_mode != pipeline_pb2.Pipeline.SYNC:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC pipeline execution modes supported; previous pipeline '
            f'run has execution mode: '
            f'{latest_pipeline_view.pipeline.execution_mode}'))

  # Get succeeded nodes in latest pipeline run.
  latest_pipeline_node_states = latest_pipeline_view.get_node_states_dict()
  previously_succeeded_nodes = [
      node for node, node_state in latest_pipeline_node_states.items()
      if node_state.is_success()
  ]
  pipeline_nodes = [
      node.node_info.id for node in pstate.get_all_pipeline_nodes(pipeline)
  ]

  # Select the "latest pipeline run" snapshot strategy; `SetInParent` marks
  # the otherwise empty sub-message as set.
  latest_pipeline_snapshot_settings = pipeline_pb2.SnapshotSettings()
  latest_pipeline_snapshot_settings.latest_pipeline_run_strategy.SetInParent()

  # Run from and to every node of the pipeline, skipping the nodes that
  # already succeeded in the latest run.
  partial_run_option = pipeline_pb2.PartialRun(
      from_nodes=pipeline_nodes,
      to_nodes=pipeline_nodes,
      skip_nodes=previously_succeeded_nodes,
      snapshot_settings=latest_pipeline_snapshot_settings)

  return initiate_pipeline_start(
      mlmd_handle, pipeline, partial_run_option=partial_run_option)