def _generate_and_test(self, use_task_queue, num_initial_executions,
                       num_tasks_generated, num_new_executions,
                       num_active_executions):
  """Generates tasks and tests the effects."""
  with self._mlmd_connection as m:
    executions = m.store.get_executions()
    self.assertLen(
        executions, num_initial_executions,
        'Expected {} execution(s) in MLMD.'.format(num_initial_executions))
    task_gen = sptg.SyncPipelineTaskGenerator(
        m, self._pipeline, self._task_queue.contains_task_id)
    tasks = task_gen.generate()
    self.assertLen(
        tasks, num_tasks_generated,
        'Expected {} task(s) to be generated.'.format(num_tasks_generated))
    executions = m.store.get_executions()
    num_total_executions = num_initial_executions + num_new_executions
    self.assertLen(
        executions, num_total_executions,
        'Expected {} execution(s) in MLMD.'.format(num_total_executions))
    active_executions = [
        e for e in executions
        if e.last_known_state == metadata_store_pb2.Execution.RUNNING
    ]
    self.assertLen(
        active_executions, num_active_executions,
        'Expected {} active execution(s) in MLMD.'.format(
            num_active_executions))
    if use_task_queue:
      for task in tasks:
        self._task_queue.enqueue(task)
    return tasks, active_executions
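# Hypothetical usage sketch (not part of the original test suite): shows how
# `_generate_and_test` might drive an individual test case. The test name and
# the expected counts are illustrative assumptions for a pipeline where one
# upstream execution already exists and exactly one downstream node is ready.
def test_task_generated_for_ready_node(self):
  tasks, active_executions = self._generate_and_test(
      use_task_queue=True,
      num_initial_executions=1,  # e.g. an upstream node already executed.
      num_tasks_generated=1,     # one task expected for the ready node.
      num_new_executions=1,      # a new execution registered for that node.
      num_active_executions=1)   # the new execution is in RUNNING state.
  # The generated task should now be tracked by the task queue.
  self.assertTrue(self._task_queue.contains_task_id(tasks[0].task_id))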
def _get_pipeline_details(mlmd_handle: metadata.Metadata,
                          task_queue: tq.TaskQueue) -> List[_PipelineDetail]:
  """Scans MLMD and returns pipeline details."""
  result = []
  contexts = mlmd_handle.store.get_contexts_by_type(_ORCHESTRATOR_RESERVED_ID)
  for context in contexts:
    active_executions = [
        e for e in mlmd_handle.store.get_executions_by_context(context.id)
        if execution_lib.is_execution_active(e)
    ]
    if len(active_executions) > 1:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.INTERNAL,
          message=(f'Expected 1 but found {len(active_executions)} active '
                   f'executions for context named: {context.name}'))
    if not active_executions:
      continue
    execution = active_executions[0]

    # TODO(goutham): Instead of parsing the pipeline IR each time, we could
    # cache the parsed pipeline IR in `initiate_pipeline_start` and reuse it.
    pipeline_ir_b64 = common_utils.get_metadata_value(
        execution.properties[_PIPELINE_IR])
    pipeline = pipeline_pb2.Pipeline()
    pipeline.ParseFromString(base64.b64decode(pipeline_ir_b64))

    stop_initiated = _is_stop_initiated(execution)
    if stop_initiated:
      generator = None
    else:
      if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle, pipeline, task_queue.contains_task_id)
      elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, pipeline, task_queue.contains_task_id)
      else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: '
                f'{pipeline.execution_mode}'))

    result.append(
        _PipelineDetail(
            context=context,
            execution=execution,
            pipeline=pipeline,
            pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
            stop_initiated=stop_initiated,
            generator=generator))
  return result
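# A minimal sketch of the `_is_stop_initiated` helper referenced above,
# assuming stop requests are recorded on the orchestrator execution as an int
# custom property. The property name `stop_initiated` is an assumption; the
# real constant may differ.
_STOP_INITIATED = 'stop_initiated'  # Assumed property name.


def _is_stop_initiated(execution: metadata_store_pb2.Execution) -> bool:
  """Returns `True` if a stop has been requested for this pipeline run."""
  if _STOP_INITIATED not in execution.custom_properties:
    return False
  return execution.custom_properties[_STOP_INITIATED].int_value == 1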
def _process_active_pipelines(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: Optional[service_jobs.ServiceJobManager],
    pipeline_states: Sequence[pstate.PipelineState]) -> None:
  """Processes active pipelines."""
  for pipeline_state in pipeline_states:
    pipeline = pipeline_state.pipeline
    execution = pipeline_state.execution
    assert execution.last_known_state in (
        metadata_store_pb2.Execution.NEW, metadata_store_pb2.Execution.RUNNING)
    if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
      updated_execution = copy.deepcopy(execution)
      updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
      mlmd_handle.store.put_executions([updated_execution])

    if service_job_manager is not None:
      # Ensure all the required services are running.
      _ensure_services(service_job_manager, pipeline_state)
      pure_service_node_ids = _get_pure_service_node_ids(
          service_job_manager, pipeline_state)
    else:
      pure_service_node_ids = set()

    # Create cancellation tasks for stop-initiated nodes if necessary.
    stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
    for node in stop_initiated_nodes:
      if node.node_info.id not in pure_service_node_ids:
        _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                         task_queue)

    ignore_node_ids = set(
        n.node_info.id for n in stop_initiated_nodes) | pure_service_node_ids

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
      generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
          mlmd_handle, pipeline, task_queue.contains_task_id, ignore_node_ids)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
      generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
          mlmd_handle, pipeline, task_queue.contains_task_id, ignore_node_ids)
    else:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.FAILED_PRECONDITION,
          message=(
              f'Only SYNC and ASYNC pipeline execution modes supported; '
              f'found pipeline with execution mode: '
              f'{pipeline.execution_mode}'))

    # TODO(goutham): Consider concurrent task generation.
    tasks = generator.generate()
    for task in tasks:
      task_queue.enqueue(task)
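# A minimal sketch of the `_maybe_enqueue_cancellation_task` helper used
# above, assuming `task_lib` exposes `exec_node_task_id_from_pipeline_node`
# and a `CancelNodeTask` type; the real helper may differ. The idea: a
# cancellation task is only meaningful if an ExecNodeTask for the node is
# currently tracked by the task queue, which is why call sites branch on the
# boolean result. `mlmd_handle` is unused here but kept to match the call
# sites.
def _maybe_enqueue_cancellation_task(mlmd_handle: metadata.Metadata,
                                     pipeline: pipeline_pb2.Pipeline,
                                     node: pipeline_pb2.PipelineNode,
                                     task_queue: tq.TaskQueue) -> bool:
  """Enqueues a CancelNodeTask if the node has an active ExecNodeTask."""
  exec_node_task_id = task_lib.exec_node_task_id_from_pipeline_node(
      pipeline, node)
  if task_queue.contains_task_id(exec_node_task_id):
    task_queue.enqueue(
        task_lib.CancelNodeTask(
            node_uid=task_lib.NodeUid.from_pipeline_node(pipeline, node)))
    return True
  return False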
def _test_no_tasks_generated_when_new(self):
  with self._mlmd_connection as m:
    task_gen = sptg.SyncPipelineTaskGenerator(m, self._pipeline,
                                              lambda _: False)
    tasks = task_gen.generate()
    self.assertEmpty(
        tasks,
        'Expected no task generation since ExampleGen is ignored for task '
        'generation and dependent downstream nodes are not ready.')
    self.assertEmpty(
        m.store.get_executions(),
        'There must not be any registered executions since no tasks were '
        'generated.')
def _get_pipeline_details(mlmd_handle: metadata.Metadata,
                          task_queue: tq.TaskQueue) -> List[_PipelineDetail]:
  """Scans MLMD and returns pipeline details."""
  result = []
  contexts = pstate.get_orchestrator_contexts(mlmd_handle)
  for context in contexts:
    try:
      pipeline_state = pstate.PipelineState.load_from_orchestrator_context(
          mlmd_handle, context)
    except status_lib.StatusNotOkError as e:
      if e.code == status_lib.Code.NOT_FOUND:
        # No active pipeline run for this context; skip it.
        continue
      # Re-raise any other error; swallowing it would leave `pipeline_state`
      # undefined below.
      raise
    if pipeline_state.is_stop_initiated():
      generator = None
    else:
      pipeline = pipeline_state.pipeline
      if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle, pipeline, task_queue.contains_task_id)
      elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, pipeline, task_queue.contains_task_id)
      else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: '
                f'{pipeline.execution_mode}'))
    result.append(
        _PipelineDetail(pipeline_state=pipeline_state, generator=generator))
  return result
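# A plausible shape for the `_PipelineDetail` container constructed above,
# inferred from its keyword arguments. The `attr` usage and the shared
# `task_gen.TaskGenerator` base type are assumptions; an earlier revision of
# `_get_pipeline_details` carried extra fields (context, execution, pipeline,
# pipeline_uid, stop_initiated).
@attr.s(auto_attribs=True, frozen=True)
class _PipelineDetail:
  pipeline_state: pstate.PipelineState
  # `None` when the pipeline is stop-initiated and task generation is skipped.
  generator: Optional[task_gen.TaskGenerator]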
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates active pipeline."""
  pipeline = pipeline_state.pipeline
  execution = pipeline_state.execution
  assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                        metadata_store_pb2.Execution.RUNNING)
  if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
    updated_execution = copy.deepcopy(execution)
    updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
    mlmd_handle.store.put_executions([updated_execution])

  # Initialize task generator for the pipeline.
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle, pipeline_state, task_queue.contains_task_id,
        service_job_manager)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    # Create cancellation tasks for stop-initiated nodes if necessary.
    stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
    for node in stop_initiated_nodes:
      if service_job_manager.is_pure_service_node(pipeline_state,
                                                  node.node_info.id):
        service_job_manager.stop_node_services(pipeline_state,
                                               node.node_info.id)
      elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                            task_queue):
        # A cancellation task was enqueued; nothing more to do for this node.
        pass
      elif service_job_manager.is_mixed_service_node(pipeline_state,
                                                     node.node_info.id):
        service_job_manager.stop_node_services(pipeline_state,
                                               node.node_info.id)
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, pipeline_state, task_queue.contains_task_id,
        service_job_manager,
        set(n.node_info.id for n in stop_initiated_nodes))
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'))

  tasks = generator.generate()
  with pipeline_state:
    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task = typing.cast(task_lib.ExecNodeTask, task)
        task_queue.enqueue(task)
      elif task_lib.is_finalize_node_task(task):
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
        task = typing.cast(task_lib.FinalizeNodeTask, task)
        pipeline_state.initiate_node_stop(task.node_uid, task.status)
      else:
        assert task_lib.is_finalize_pipeline_task(task)
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
        assert len(tasks) == 1
        task = typing.cast(task_lib.FinalizePipelineTask, task)
        if task.status.code == status_lib.Code.OK:
          logging.info('Pipeline run successful; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        else:
          logging.info('Pipeline run failed; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        pipeline_state.initiate_stop(task.status)
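# A minimal sketch of the `_get_stop_initiated_nodes` helper used above and in
# `_process_active_pipelines`. The `pipeline_state.node_stop_initiated`
# accessor is an assumption about how per-node stop requests are recorded; the
# real helper may differ.
def _get_stop_initiated_nodes(
    pipeline_state: pstate.PipelineState) -> List[pipeline_pb2.PipelineNode]:
  """Returns nodes of the pipeline whose stop has been requested."""
  result = []
  with pipeline_state:
    for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
      node_uid = task_lib.NodeUid.from_pipeline_node(pipeline_state.pipeline,
                                                     node)
      if pipeline_state.node_stop_initiated(node_uid):
        result.append(node)
  return result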
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates active pipeline."""
  pipeline = pipeline_state.pipeline
  with pipeline_state:
    assert pipeline_state.is_active()
    if pipeline_state.get_pipeline_execution_state() != (
        metadata_store_pb2.Execution.RUNNING):
      pipeline_state.set_pipeline_execution_state(
          metadata_store_pb2.Execution.RUNNING)
    orchestration_options = pipeline_state.get_orchestration_options()
    logging.info('Orchestration options: %s', orchestration_options)
    deadline_secs = orchestration_options.deadline_secs
    if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC and
        deadline_secs > 0 and
        time.time() - pipeline_state.pipeline_creation_time_secs_since_epoch()
        > deadline_secs):
      logging.error(
          'Aborting pipeline due to exceeding deadline (%s secs); '
          'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
      pipeline_state.initiate_stop(
          status_lib.Status(
              code=status_lib.Code.DEADLINE_EXCEEDED,
              message=('Pipeline aborted due to exceeding deadline '
                       f'({deadline_secs} secs)')))
      return

  def _filter_by_state(node_infos: List[_NodeInfo],
                       state_str: str) -> List[_NodeInfo]:
    return [n for n in node_infos if n.state.state == state_str]

  node_infos = _get_node_infos(pipeline_state)
  stopping_node_infos = _filter_by_state(node_infos, pstate.NodeState.STOPPING)

  # Tracks nodes stopped in the current iteration.
  stopped_node_infos: List[_NodeInfo] = []

  # Create cancellation tasks for nodes in state STOPPING.
  for node_info in stopping_node_infos:
    if service_job_manager.is_pure_service_node(pipeline_state,
                                                node_info.node.node_info.id):
      if service_job_manager.stop_node_services(pipeline_state,
                                                node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                          node_info.node, task_queue):
      pass
    elif service_job_manager.is_mixed_service_node(
        pipeline_state, node_info.node.node_info.id):
      if service_job_manager.stop_node_services(pipeline_state,
                                                node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    else:
      stopped_node_infos.append(node_info)

  # Change the state of stopped nodes from STOPPING to STOPPED.
  if stopped_node_infos:
    with pipeline_state:
      for node_info in stopped_node_infos:
        node_uid = task_lib.NodeUid.from_pipeline_node(pipeline,
                                                       node_info.node)
        with pipeline_state.node_state_update_context(node_uid) as node_state:
          node_state.update(pstate.NodeState.STOPPED, node_state.status)

  # Initialize task generator for the pipeline.
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle, task_queue.contains_task_id, service_job_manager,
        fail_fast=orchestration_options.fail_fast)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, task_queue.contains_task_id, service_job_manager)
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'))

  tasks = generator.generate(pipeline_state)

  with pipeline_state:
    # Handle all the UpdateNodeStateTasks by updating node states.
    for task in tasks:
      if task_lib.is_update_node_state_task(task):
        task = typing.cast(task_lib.UpdateNodeStateTask, task)
        with pipeline_state.node_state_update_context(
            task.node_uid) as node_state:
          node_state.update(task.state, task.status)

    tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

    # If there are still nodes in state STARTING, change them to STARTED.
    for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
      node_uid = task_lib.NodeUid.from_pipeline_node(pipeline_state.pipeline,
                                                     node)
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        if node_state.state == pstate.NodeState.STARTING:
          node_state.update(pstate.NodeState.STARTED)

    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task = typing.cast(task_lib.ExecNodeTask, task)
        task_queue.enqueue(task)
      else:
        assert task_lib.is_finalize_pipeline_task(task)
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
        assert len(tasks) == 1
        task = typing.cast(task_lib.FinalizePipelineTask, task)
        if task.status.code == status_lib.Code.OK:
          logging.info('Pipeline run successful; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        else:
          logging.info('Pipeline run failed; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        pipeline_state.initiate_stop(task.status)
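# Illustrative definitions of the `_NodeInfo` container and `_get_node_infos`
# helper referenced above. The `attr` usage and the
# `pipeline_state.get_node_state` accessor are assumptions; the real
# definitions may differ.
@attr.s(auto_attribs=True)
class _NodeInfo:
  node: pipeline_pb2.PipelineNode
  state: pstate.NodeState


def _get_node_infos(pipeline_state: pstate.PipelineState) -> List[_NodeInfo]:
  """Returns a `_NodeInfo` for every node in the pipeline."""
  result = []
  with pipeline_state:
    for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
      node_uid = task_lib.NodeUid.from_pipeline_node(pipeline_state.pipeline,
                                                     node)
      result.append(
          _NodeInfo(node=node, state=pipeline_state.get_node_state(node_uid)))
  return result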