def _abort_node_task(self,
                     node_uid: task_lib.NodeUid) -> task_lib.FinalizeNodeTask:
  """Builds a `FinalizeNodeTask` aborting execution of the given node.

  Logs an error about the unhealthy service job before constructing the task.

  Args:
    node_uid: Uid of the node whose execution should be aborted.

  Returns:
    A `FinalizeNodeTask` for the node carrying an `ABORTED` status.
  """
  logging.error(
      'Required service node not running or healthy, node uid: %s', node_uid)
  abort_status = status_lib.Status(
      code=status_lib.Code.ABORTED,
      message=(f'Aborting node execution as the associated service '
               f'job is not running or healthy; problematic node '
               f'uid: {node_uid}'))
  return task_lib.FinalizeNodeTask(node_uid=node_uid, status=abort_status)
def generate(self) -> List[task_lib.Task]:
  """Generates tasks for all executable nodes in the async pipeline.

  The returned tasks must have `exec_task` populated. List may be empty if
  no nodes are ready for execution.

  Returns:
    A `list` of tasks to execute.
  """
  result = []
  for node in [n.pipeline_node for n in self._pipeline.nodes]:
    node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)
    node_id = node.node_info.id
    if node_id in self._ignore_node_ids:
      logging.info('Ignoring node for task generation: %s', node_uid)
      continue
    if self._service_job_manager.is_pure_service_node(self._pipeline_state,
                                                      node_id):
      service_status = self._service_job_manager.ensure_node_services(
          self._pipeline_state, node_id)
      if service_status != service_jobs.ServiceStatus.RUNNING:
        # `_abort_node_task` logs the error and builds the finalize task;
        # delegate to it rather than duplicating the abort logic inline.
        result.append(self._abort_node_task(node_uid))
      # Pure service nodes never yield exec tasks; skip further generation.
      continue
    # If a task for the node is already tracked by the task queue, it need
    # not be considered for generation again.
    if self._is_task_id_tracked_fn(
        task_lib.exec_node_task_id_from_pipeline_node(self._pipeline, node)):
      continue
    task = self._generate_task(self._mlmd_handle, node)
    if task:
      result.append(task)
  return result
def test_handling_finalize_node_task(self, task_gen):
  with self._mlmd_connection as mlmd_handle:
    # Set up a two-node pipeline and start it.
    pipeline = _test_pipeline('pipeline1')
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline_ops.initiate_pipeline_start(mlmd_handle, pipeline)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)

    transform_uid = task_lib.NodeUid(
        pipeline_uid=pipeline_uid, node_id='Transform')
    trainer_uid = task_lib.NodeUid(
        pipeline_uid=pipeline_uid, node_id='Trainer')
    finalize_reason = status_lib.Status(
        code=status_lib.Code.ABORTED, message='foo bar')
    # Generator yields an exec task for Transform and a finalize task for
    # Trainer on its single invocation.
    task_gen.return_value.generate.side_effect = [[
        test_utils.create_exec_node_task(transform_uid),
        task_lib.FinalizeNodeTask(node_uid=trainer_uid,
                                  status=finalize_reason),
    ]]

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(mlmd_handle, task_queue,
                             service_jobs.DummyServiceJobManager())
    task_gen.return_value.generate.assert_called_once()

    # The exec task for Transform should be the only task enqueued.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)

    # The finalize task should have initiated node stop for Trainer with
    # the given reason.
    with pstate.PipelineState.load(mlmd_handle, pipeline_uid) as pipeline_state:
      self.assertEqual(
          finalize_reason,
          pipeline_state.node_stop_initiated_reason(trainer_uid))