def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Runs one orchestration step for a pipeline that is currently active.

  Promotes a NEW pipeline execution to RUNNING, builds the task generator
  appropriate for the pipeline's execution mode (issuing cancellations for
  stop-initiated nodes in ASYNC mode), then dispatches the generated tasks.

  Args:
    mlmd_handle: Handle to the MLMD store.
    task_queue: Queue into which executable node tasks are enqueued.
    service_job_manager: Manages (pure/mixed) service node jobs.
    pipeline_state: State wrapper for the pipeline being orchestrated.

  Raises:
    status_lib.StatusNotOkError: If the pipeline's execution mode is neither
      SYNC nor ASYNC.
  """
  pipeline = pipeline_state.pipeline
  execution = pipeline_state.execution
  # An active pipeline's MLMD execution must be NEW or RUNNING.
  assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                        metadata_store_pb2.Execution.RUNNING)
  # First orchestration of a NEW pipeline: record it as RUNNING in MLMD.
  if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
    running_copy = copy.deepcopy(execution)
    running_copy.last_known_state = metadata_store_pb2.Execution.RUNNING
    mlmd_handle.store.put_executions([running_copy])

  # Build the task generator matching the pipeline's execution mode.
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle, pipeline_state, task_queue.contains_task_id,
        service_job_manager)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    # Issue cancellations for any stop-initiated nodes before generating.
    stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
    for node in stop_initiated_nodes:
      node_id = node.node_info.id
      if service_job_manager.is_pure_service_node(pipeline_state, node_id):
        # Pure service nodes are stopped directly via the service manager.
        service_job_manager.stop_node_services(pipeline_state, node_id)
        continue
      if _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                          task_queue):
        # A cancellation task was enqueued; nothing more to do here.
        continue
      if service_job_manager.is_mixed_service_node(pipeline_state, node_id):
        # No cancellation task was needed; stop the mixed node's services.
        service_job_manager.stop_node_services(pipeline_state, node_id)
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, pipeline_state, task_queue.contains_task_id,
        service_job_manager,
        set(n.node_info.id for n in stop_initiated_nodes))
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'))

  tasks = generator.generate()
  with pipeline_state:
    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task_queue.enqueue(typing.cast(task_lib.ExecNodeTask, task))
      elif task_lib.is_finalize_node_task(task):
        # Node finalization only occurs in ASYNC mode.
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
        finalize_node = typing.cast(task_lib.FinalizeNodeTask, task)
        pipeline_state.initiate_node_stop(finalize_node.node_uid,
                                          finalize_node.status)
      else:
        # Pipeline finalization only occurs in SYNC mode and, when present,
        # is the sole generated task.
        assert task_lib.is_finalize_pipeline_task(task)
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
        assert len(tasks) == 1
        finalize_pipeline = typing.cast(task_lib.FinalizePipelineTask, task)
        if finalize_pipeline.status.code == status_lib.Code.OK:
          logging.info('Pipeline run successful; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        else:
          logging.info('Pipeline run failed; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        pipeline_state.initiate_stop(finalize_pipeline.status)
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Runs one orchestration step for a pipeline that is currently active.

  Marks the pipeline execution RUNNING if needed, enforces the optional SYNC
  deadline, transitions STOPPING nodes toward STOPPED (issuing cancellations
  and stopping service jobs), then generates and dispatches tasks for the
  pipeline's execution mode.

  Args:
    mlmd_handle: Handle to the MLMD store.
    task_queue: Queue into which executable node tasks are enqueued.
    service_job_manager: Manages (pure/mixed) service node jobs.
    pipeline_state: State wrapper for the pipeline being orchestrated.

  Raises:
    status_lib.StatusNotOkError: If the pipeline's execution mode is neither
      SYNC nor ASYNC.
  """
  pipeline = pipeline_state.pipeline
  with pipeline_state:
    assert pipeline_state.is_active()
    # Ensure the pipeline execution is recorded as RUNNING.
    if pipeline_state.get_pipeline_execution_state() != (
        metadata_store_pb2.Execution.RUNNING):
      pipeline_state.set_pipeline_execution_state(
          metadata_store_pb2.Execution.RUNNING)
    orchestration_options = pipeline_state.get_orchestration_options()
    logging.info('Orchestration options: %s', orchestration_options)
    deadline_secs = orchestration_options.deadline_secs
    # SYNC pipelines may carry a deadline; abort the run once it is exceeded.
    elapsed_secs = (
        time.time() -
        pipeline_state.pipeline_creation_time_secs_since_epoch())
    if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC and
        deadline_secs > 0 and elapsed_secs > deadline_secs):
      logging.error(
          'Aborting pipeline due to exceeding deadline (%s secs); '
          'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
      pipeline_state.initiate_stop(
          status_lib.Status(
              code=status_lib.Code.DEADLINE_EXCEEDED,
              message=('Pipeline aborted due to exceeding deadline '
                       f'({deadline_secs} secs)')))
      return

  all_node_infos = _get_node_infos(pipeline_state)
  stopping_node_infos = [
      n for n in all_node_infos if n.state.state == pstate.NodeState.STOPPING
  ]

  # Nodes observed to have fully stopped during this iteration.
  newly_stopped: List[_NodeInfo] = []

  # Drive each STOPPING node toward STOPPED.
  for info in stopping_node_infos:
    node_id = info.node.node_info.id
    if service_job_manager.is_pure_service_node(pipeline_state, node_id):
      # Pure service nodes stop as soon as their services are stopped.
      if service_job_manager.stop_node_services(pipeline_state, node_id):
        newly_stopped.append(info)
      continue
    if _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, info.node,
                                        task_queue):
      # Cancellation is in flight; the node is not stopped yet.
      continue
    if service_job_manager.is_mixed_service_node(pipeline_state, node_id):
      # With no cancellation pending, stopping services completes the stop.
      if service_job_manager.stop_node_services(pipeline_state, node_id):
        newly_stopped.append(info)
    else:
      newly_stopped.append(info)

  # Record the STOPPING -> STOPPED transition for nodes that fully stopped.
  if newly_stopped:
    with pipeline_state:
      for info in newly_stopped:
        node_uid = task_lib.NodeUid.from_pipeline_node(pipeline, info.node)
        with pipeline_state.node_state_update_context(
            node_uid) as node_state:
          node_state.update(pstate.NodeState.STOPPED, node_state.status)

  # Build the task generator matching the pipeline's execution mode.
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle,
        task_queue.contains_task_id,
        service_job_manager,
        fail_fast=orchestration_options.fail_fast)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, task_queue.contains_task_id, service_job_manager)
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'))

  tasks = generator.generate(pipeline_state)

  with pipeline_state:
    # Apply every UpdateNodeStateTask to the stored node states, then drop
    # them from the task list.
    for raw_task in tasks:
      if not task_lib.is_update_node_state_task(raw_task):
        continue
      update_task = typing.cast(task_lib.UpdateNodeStateTask, raw_task)
      with pipeline_state.node_state_update_context(
          update_task.node_uid) as node_state:
        node_state.update(update_task.state, update_task.status)
    tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

    # Any node still in STARTING is promoted to STARTED.
    for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
      node_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline_state.pipeline, node)
      with pipeline_state.node_state_update_context(
          node_uid) as node_state:
        if node_state.state == pstate.NodeState.STARTING:
          node_state.update(pstate.NodeState.STARTED)

    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task_queue.enqueue(typing.cast(task_lib.ExecNodeTask, task))
        continue
      # Pipeline finalization only occurs in SYNC mode and, when present, is
      # the sole remaining task.
      assert task_lib.is_finalize_pipeline_task(task)
      assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
      assert len(tasks) == 1
      finalize_task = typing.cast(task_lib.FinalizePipelineTask, task)
      if finalize_task.status.code == status_lib.Code.OK:
        logging.info('Pipeline run successful; pipeline uid: %s',
                     pipeline_state.pipeline_uid)
      else:
        logging.info('Pipeline run failed; pipeline uid: %s',
                     pipeline_state.pipeline_uid)
      pipeline_state.initiate_stop(finalize_task.status)