def _process_stop_initiated_pipelines(
        mlmd_handle: metadata.Metadata,
        task_queue: tq.TaskQueue,
        pipeline_details: Sequence[_PipelineDetail]) -> None:
    """Cancels outstanding node work for pipelines whose stop was initiated.

    For every node of each pipeline, one of the following happens:

    * If the task queue already holds the node's ExecNodeTask, a
      CancelNodeTask for that node is enqueued.
    * Otherwise the node's executions are read from MLMD and, if a task can
      be generated from an active execution, that task is enqueued with
      `is_cancelled=True`. An active execution without a queued ExecNodeTask
      may be due to an orchestrator restart after the pipeline stop was
      initiated but before the schedulers could finish; the cancelled task
      gives the scheduler a chance to finish gracefully.

    When no task was enqueued for any node of a pipeline, a deep copy of the
    pipeline's execution with `last_known_state` set to CANCELED is written
    to MLMD.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task_queue: The queue that is checked for ExecNodeTasks and into which
        cancellation tasks are enqueued.
      pipeline_details: Details of the pipelines to process; each one
        provides `pipeline_state.pipeline` and `pipeline_state.execution`.
    """
    for pipeline_detail in pipeline_details:
        pipeline_state = pipeline_detail.pipeline_state
        pipeline = pipeline_state.pipeline
        pipeline_execution = pipeline_state.execution

        # Number of tasks enqueued for this pipeline during this pass.
        enqueued_count = 0
        for node in _get_all_pipeline_nodes(pipeline):
            task_id = task_lib.exec_node_task_id_from_pipeline_node(
                pipeline, node)

            if not task_queue.contains_task_id(task_id):
                # Nothing queued for the node: fall back to MLMD to look for
                # an active execution that still needs a graceful shutdown.
                node_executions = task_gen_utils.get_executions(
                    mlmd_handle, node)
                cancelled_task = (
                    task_gen_utils.generate_task_from_active_execution(
                        mlmd_handle, pipeline, node, node_executions,
                        is_cancelled=True))
                if cancelled_task:
                    task_queue.enqueue(cancelled_task)
                    enqueued_count += 1
                continue

            # The node's ExecNodeTask is queued, so request its cancellation.
            node_uid = task_lib.NodeUid.from_pipeline_node(pipeline, node)
            task_queue.enqueue(task_lib.CancelNodeTask(node_uid=node_uid))
            enqueued_count += 1

        if enqueued_count:
            # Some node still has work pending; leave the pipeline as is.
            continue

        # No node needed a task, so record the pipeline as CANCELED.
        cancelled_execution = copy.deepcopy(pipeline_execution)
        cancelled_execution.last_known_state = (
            metadata_store_pb2.Execution.CANCELED)
        mlmd_handle.store.put_executions([cancelled_execution])
def test_task_ids(self):
    """Checks that `task_id` is the pair (task type name, node uid)."""
    trainer_uid = task_lib.NodeUid(
        pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
        node_id='Trainer')

    # ExecNodeTask built through the shared test helper.
    exec_task = test_utils.create_exec_node_task(trainer_uid)
    expected_exec_id = ('ExecNodeTask', trainer_uid)
    self.assertEqual(expected_exec_id, exec_task.task_id)

    # CancelNodeTask built directly from the node uid.
    cancel_task = task_lib.CancelNodeTask(node_uid=trainer_uid)
    expected_cancel_id = ('CancelNodeTask', trainer_uid)
    self.assertEqual(expected_cancel_id, cancel_task.task_id)
def test_task_ids(self):
    """Checks that `task_id` is the pair (task type name, node uid)."""
    uid_fields = dict(
        pipeline_id='pipeline', pipeline_run_id='run0', node_id='Trainer')
    trainer_uid = task_lib.NodeUid(**uid_fields)

    # ExecNodeTask constructed directly with an explicit execution id.
    exec_task = task_lib.ExecNodeTask(node_uid=trainer_uid, execution_id=123)
    self.assertEqual(('ExecNodeTask', trainer_uid), exec_task.task_id)

    # CancelNodeTask for the same node.
    cancel_task = task_lib.CancelNodeTask(node_uid=trainer_uid)
    self.assertEqual(('CancelNodeTask', trainer_uid), cancel_task.task_id)
def test_task_type_ids(self):
    """Checks `task_type_id()` when called on task classes and instances."""
    # Called on the classes themselves.
    class_level_cases = (
        ('ExecNodeTask', task_lib.ExecNodeTask),
        ('CancelNodeTask', task_lib.CancelNodeTask),
    )
    for expected_type_id, task_cls in class_level_cases:
        self.assertEqual(expected_type_id, task_cls.task_type_id())

    # Called on instances that share one node uid.
    trainer_uid = task_lib.NodeUid(
        pipeline_id='pipeline', pipeline_run_id='run0', node_id='Trainer')

    exec_task = test_utils.create_exec_node_task(trainer_uid)
    self.assertEqual('ExecNodeTask', exec_task.task_type_id())

    cancel_task = task_lib.CancelNodeTask(node_uid=trainer_uid)
    self.assertEqual('CancelNodeTask', cancel_task.task_type_id())
def _maybe_enqueue_cancellation_task(
        mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline,
        node: pipeline_pb2.PipelineNode,
        task_queue: tq.TaskQueue,
        pause: bool = False) -> bool:
    """Enqueues a cancellation task for a node when one is needed.

    Two situations lead to a task being enqueued:

    1. The node's ExecNodeTask is present in the task queue. A
       `CancelNodeTask` carrying the `pause` flag is enqueued.
    2. No ExecNodeTask is queued, `pause` is false, and a task can be
       generated from an active execution of the node in MLMD. This may be
       due to an orchestrator restart after stopping was initiated but
       before the schedulers could finish, so an ExecNodeTask with
       `is_cancelled` set is enqueued to give the scheduler a chance to
       finish gracefully.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: The pipeline containing the node to cancel.
      node: The node to cancel.
      task_queue: A `TaskQueue` instance into which any cancellation tasks
        will be enqueued.
      pause: Whether the cancellation is to pause the node rather than
        cancelling the execution.

    Returns:
      `True` if a cancellation task was enqueued. `False` if node is already
      stopped or no cancellation was required.
    """
    queued_task_id = task_lib.exec_node_task_id_from_pipeline_node(
        pipeline, node)

    if task_queue.contains_task_id(queued_task_id):
        node_uid = task_lib.NodeUid.from_pipeline_node(pipeline, node)
        task_queue.enqueue(
            task_lib.CancelNodeTask(node_uid=node_uid, pause=pause))
        return True

    if pause:
        # When pausing, MLMD is not consulted for active executions.
        return False

    node_executions = task_gen_utils.get_executions(mlmd_handle, node)
    cancelled_task = task_gen_utils.generate_task_from_active_execution(
        mlmd_handle, pipeline, node, node_executions, is_cancelled=True)
    if not cancelled_task:
        return False

    task_queue.enqueue(cancelled_task)
    return True
def _test_cancel_node_task(node_id, pipeline_id):
    """Builds a `CancelNodeTask` for the given node id and pipeline id."""
    pipeline_uid = task_lib.PipelineUid(pipeline_id=pipeline_id)
    return task_lib.CancelNodeTask(
        node_uid=task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id=node_id))
def _test_cancel_task(node_id, pipeline_id, pipeline_run_id=None):
    """Builds a `CancelNodeTask` for the given node, pipeline and run ids.

    `pipeline_run_id` defaults to `None` and is passed through to `NodeUid`
    unchanged.
    """
    return task_lib.CancelNodeTask(
        node_uid=task_lib.NodeUid(
            pipeline_id=pipeline_id,
            pipeline_run_id=pipeline_run_id,
            node_id=node_id))