Example #1
0
 def _abort_node_task(
         self, node_uid: task_lib.NodeUid) -> task_lib.FinalizeNodeTask:
     """Builds a `FinalizeNodeTask` that aborts execution of the given node.

     Logs an error and wraps an ABORTED status into a finalize task; the
     caller is responsible for enqueueing / dispatching it.

     Args:
       node_uid: Uid of the node whose execution should be aborted.

     Returns:
       A `FinalizeNodeTask` carrying an ABORTED status for `node_uid`.
     """
     logging.error(
         'Required service node not running or healthy, node uid: %s',
         node_uid)
     abort_status = status_lib.Status(
         code=status_lib.Code.ABORTED,
         message=(f'Aborting node execution as the associated service '
                  f'job is not running or healthy; problematic node '
                  f'uid: {node_uid}'))
     return task_lib.FinalizeNodeTask(node_uid=node_uid, status=abort_status)
Example #2
0
    def generate(self) -> List[task_lib.Task]:
        """Generates tasks for every executable node of the async pipeline.

        The returned tasks must have `exec_task` populated. The list may be
        empty if no nodes are currently ready for execution.

        Returns:
          A `list` of tasks to execute.
        """
        tasks = []
        for pipeline_node in [n.pipeline_node for n in self._pipeline.nodes]:
            uid = task_lib.NodeUid.from_pipeline_node(self._pipeline,
                                                      pipeline_node)
            name = pipeline_node.node_info.id

            # Nodes explicitly marked as ignored never produce tasks.
            if name in self._ignore_node_ids:
                logging.info('Ignoring node for task generation: %s', uid)
                continue

            # Pure service nodes are handled by the service job manager; here
            # we only verify their health and finalize them when unhealthy.
            if self._service_job_manager.is_pure_service_node(
                    self._pipeline_state, name):
                service_status = self._service_job_manager.ensure_node_services(
                    self._pipeline_state, name)
                if service_status != service_jobs.ServiceStatus.RUNNING:
                    logging.error(
                        'Required service node not running or healthy, node uid: %s',
                        uid)
                    abort_status = status_lib.Status(
                        code=status_lib.Code.ABORTED,
                        message=(
                            f'Aborting node execution as the associated service '
                            f'job is not running or healthy; problematic node '
                            f'uid: {uid}'))
                    tasks.append(
                        task_lib.FinalizeNodeTask(node_uid=uid,
                                                  status=abort_status))
                continue

            # A node whose exec task is already tracked by the task queue
            # need not be considered for generation again.
            tracked_task_id = task_lib.exec_node_task_id_from_pipeline_node(
                self._pipeline, pipeline_node)
            if self._is_task_id_tracked_fn(tracked_task_id):
                continue

            maybe_task = self._generate_task(self._mlmd_handle, pipeline_node)
            if maybe_task:
                tasks.append(maybe_task)
        return tasks
Example #3
0
    def test_handling_finalize_node_task(self, task_gen):
        """Orchestrate surfaces exec tasks and applies finalize tasks."""
        with self._mlmd_connection as mlmd_handle:
            pipeline = _test_pipeline('pipeline1')
            for node_id in ('Transform', 'Trainer'):
                pipeline.nodes.add().pipeline_node.node_info.id = node_id
            pipeline_ops.initiate_pipeline_start(mlmd_handle, pipeline)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            transform_uid = task_lib.NodeUid(
                pipeline_uid=pipeline_uid, node_id='Transform')
            trainer_uid = task_lib.NodeUid(
                pipeline_uid=pipeline_uid, node_id='Trainer')
            stop_status = status_lib.Status(
                code=status_lib.Code.ABORTED, message='foo bar')
            task_gen.return_value.generate.side_effect = [[
                test_utils.create_exec_node_task(transform_uid),
                task_lib.FinalizeNodeTask(
                    node_uid=trainer_uid, status=stop_status),
            ]]

            task_queue = tq.TaskQueue()
            pipeline_ops.orchestrate(mlmd_handle, task_queue,
                                     service_jobs.DummyServiceJobManager())
            task_gen.return_value.generate.assert_called_once()

            # Only the exec-node task should surface on the queue.
            dequeued = task_queue.dequeue()
            task_queue.task_done(dequeued)
            self.assertTrue(task_lib.is_exec_node_task(dequeued))
            self.assertEqual(
                test_utils.create_node_uid('pipeline1', 'Transform'),
                dequeued.node_uid)

            # The finalize task should have initiated node stop with the
            # given status; verify via reloaded pipeline state.
            with pstate.PipelineState.load(mlmd_handle,
                                           pipeline_uid) as pipeline_state:
                self.assertEqual(
                    stop_status,
                    pipeline_state.node_stop_initiated_reason(trainer_uid))