def _run_next(self, use_task_queue, expect_nodes, finish_nodes=None, artifact_custom_properties=None, fail_fast=False):
  """Runs a complete cycle of task generation and simulating their completion.

  Args:
    use_task_queue: Whether to use task queue.
    expect_nodes: List of nodes whose task generation is expected.
    finish_nodes: List of nodes whose completion should be simulated. If
      `None` (default), all of `expect_nodes` will be finished.
    artifact_custom_properties: A dict of custom properties to attach to the
      output artifacts.
    fail_fast: If `True`, pipeline is aborted immediately if any node fails.
  """
  generated_tasks = self._generate(use_task_queue, True, fail_fast=fail_fast)
  # Every generated task is expected to be an ExecNodeTask.
  for generated_task in generated_tasks:
    self.assertTrue(task_lib.is_exec_node_task(generated_task))
  # Node ids of the generated tasks must match the expected nodes
  # (order-insensitive).
  self.assertCountEqual(
      [node.node_info.id for node in expect_nodes],
      [generated_task.node_uid.node_id for generated_task in generated_tasks])
  # Default to finishing every expected node when `finish_nodes` is omitted.
  if finish_nodes is None:
    nodes_to_finish = {node.node_info.id for node in expect_nodes}
  else:
    nodes_to_finish = {node.node_info.id for node in finish_nodes}
  for generated_task in generated_tasks:
    if generated_task.node_uid.node_id in nodes_to_finish:
      self._finish_node_execution(
          use_task_queue,
          generated_task,
          artifact_custom_properties=artifact_custom_properties)
def test_service_job_success(self):
  """Tests task generation when example-gen service job succeeds."""
  # Simulate a completed example-gen (service job) run in MLMD.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)
  # With example-gen done, three tasks are expected: node-state updates for
  # example-gen and stats-gen, plus an ExecNodeTask for stats-gen.
  [
      eg_update_node_state_task, sg_update_node_state_task, sg_exec_node_task
  ] = self._generate_and_test(
      True,
      num_initial_executions=1,
      num_tasks_generated=3,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._stats_gen])
  # example-gen's service job succeeded, so its node is marked COMPLETE.
  self.assertTrue(
      task_lib.is_update_node_state_task(eg_update_node_state_task))
  self.assertEqual('my_example_gen',
                   eg_update_node_state_task.node_uid.node_id)
  self.assertEqual(pstate.NodeState.COMPLETE, eg_update_node_state_task.state)
  # stats-gen transitions to RUNNING with a corresponding ExecNodeTask.
  self.assertTrue(
      task_lib.is_update_node_state_task(sg_update_node_state_task))
  self.assertEqual('my_statistics_gen',
                   sg_update_node_state_task.node_uid.node_id)
  self.assertEqual(pstate.NodeState.RUNNING, sg_update_node_state_task.state)
  self.assertTrue(task_lib.is_exec_node_task(sg_exec_node_task))
def _generate_and_test(self, use_task_queue, num_initial_executions,
                       num_tasks_generated, num_new_executions,
                       num_active_executions):
  """Generates tasks once and verifies execution counts in MLMD.

  Args:
    use_task_queue: If `True`, exec-node tasks are enqueued on the task queue.
    num_initial_executions: Expected executions in MLMD before generation.
    num_tasks_generated: Expected number of generated tasks.
    num_new_executions: Expected executions created by this generation.
    num_active_executions: Expected executions in RUNNING state afterwards.

  Returns:
    A `(tasks, active_executions)` tuple.
  """
  with self._mlmd_connection as mlmd_handle:
    # Sanity check the executions already present before generating.
    self.assertLen(
        mlmd_handle.store.get_executions(), num_initial_executions,
        f'Expected {num_initial_executions} execution(s) in MLMD.')
    pipeline_state = pstate.PipelineState.new(mlmd_handle, self._pipeline)
    generator = sptg.SyncPipelineTaskGenerator(
        mlmd_handle, pipeline_state, self._task_queue.contains_task_id,
        self._mock_service_job_manager)
    tasks = generator.generate()
    self.assertLen(tasks, num_tasks_generated,
                   f'Expected {num_tasks_generated} task(s) to be generated.')
    executions = mlmd_handle.store.get_executions()
    expected_total = num_initial_executions + num_new_executions
    self.assertLen(executions, expected_total,
                   f'Expected {expected_total} execution(s) in MLMD.')
    active_executions = [
        execution for execution in executions
        if execution.last_known_state == metadata_store_pb2.Execution.RUNNING
    ]
    self.assertLen(
        active_executions, num_active_executions,
        f'Expected {num_active_executions} active execution(s) in MLMD.')
    if use_task_queue:
      for task in tasks:
        if task_lib.is_exec_node_task(task):
          self._task_queue.enqueue(task)
    return tasks, active_executions
def run_generator(mlmd_connection, generator_class, pipeline, task_queue,
                  use_task_queue, service_job_manager,
                  ignore_update_node_state_tasks=False, fail_fast=None):
  """Generates tasks for testing.

  Args:
    mlmd_connection: A metadata connection (context manager).
    generator_class: Task generator class to instantiate.
    pipeline: The pipeline IR.
    task_queue: The task queue instance.
    use_task_queue: If `True`, generated exec-node tasks are enqueued.
    service_job_manager: The service job manager to pass to the generator.
    ignore_update_node_state_tasks: If `True`, update-node-state tasks are
      filtered out of the returned list (after being applied).
    fail_fast: Forwarded to the generator only when explicitly set, so
      generator classes without the keyword still work.

  Returns:
    The list of generated tasks.
  """
  with mlmd_connection as mlmd_handle:
    pipeline_state = get_or_create_pipeline_state(mlmd_handle, pipeline)
    generator_kwargs = {
        'mlmd_handle': mlmd_handle,
        'is_task_id_tracked_fn': task_queue.contains_task_id,
        'service_job_manager': service_job_manager,
    }
    if fail_fast is not None:
      generator_kwargs['fail_fast'] = fail_fast
    tasks = generator_class(**generator_kwargs).generate(pipeline_state)
    if use_task_queue:
      for task in tasks:
        if task_lib.is_exec_node_task(task):
          task_queue.enqueue(task)
    # Apply node-state updates so that subsequent generations observe them.
    for task in tasks:
      if not task_lib.is_update_node_state_task(task):
        continue
      with pipeline_state:
        with pipeline_state.node_state_update_context(
            task.node_uid) as node_state:
          node_state.update(task.state, task.status)
    if ignore_update_node_state_tasks:
      tasks = [
          task for task in tasks
          if not task_lib.is_update_node_state_task(task)
      ]
  return tasks
def _handle_task(self, task: task_lib.Task) -> None:
  """Routes `task` to the handler for its concrete task type.

  Raises:
    RuntimeError: If `task` is neither an exec-node nor a cancel-node task.
  """
  if task_lib.is_exec_node_task(task):
    self._handle_exec_node_task(typing.cast(task_lib.ExecNodeTask, task))
    return
  if task_lib.is_cancel_node_task(task):
    self._handle_cancel_node_task(typing.cast(task_lib.CancelNodeTask, task))
    return
  raise RuntimeError('Cannot dispatch bad task: {}'.format(task))
def _make_executor_output(task, code=status_lib.Code.OK):
  """Builds an `ExecutorOutput` proto mirroring `task`'s output artifacts.

  Args:
    task: An `ExecNodeTask` whose output artifacts are copied into the proto.
    code: Status code recorded in the executor output's execution result.

  Returns:
    An `execution_result_pb2.ExecutorOutput` proto.
  """
  assert task_lib.is_exec_node_task(task)
  output = execution_result_pb2.ExecutorOutput()
  for key, artifacts in task.output_artifacts.items():
    # Map entries are only created for keys that actually have artifacts.
    for artifact in artifacts:
      output.output_artifacts[key].artifacts.add().CopyFrom(
          artifact.mlmd_artifact)
  output.execution_result.code = code
  return output
def run_generator_and_test(test_case,
                           mlmd_connection,
                           generator_class,
                           pipeline,
                           task_queue,
                           use_task_queue,
                           service_job_manager,
                           num_initial_executions,
                           num_tasks_generated,
                           num_new_executions,
                           num_active_executions,
                           expected_exec_nodes=None,
                           ignore_node_ids=None):
  """Runs generator.generate() and tests the effects.

  Args:
    test_case: The `TestCase` instance used for assertions.
    mlmd_connection: A metadata connection (context manager).
    generator_class: Task generator class to instantiate.
    pipeline: The pipeline IR.
    task_queue: The task queue instance.
    use_task_queue: If `True`, generated exec-node tasks are enqueued.
    service_job_manager: Service job manager; a `DummyServiceJobManager` is
      substituted when `None`.
    num_initial_executions: Expected executions in MLMD before generation.
    num_tasks_generated: Expected number of generated tasks.
    num_new_executions: Expected executions created by this generation.
    num_active_executions: Expected active executions afterwards.
    expected_exec_nodes: If set, the i-th task is verified against the i-th
      expected node and the i-th active execution.
    ignore_node_ids: Forwarded only to `AsyncPipelineTaskGenerator`.

  Returns:
    The list of generated tasks.
  """
  if service_job_manager is None:
    service_job_manager = service_jobs.DummyServiceJobManager()
  with mlmd_connection as m:
    executions = m.store.get_executions()
    test_case.assertLen(
        executions, num_initial_executions,
        f'Expected {num_initial_executions} execution(s) in MLMD.')
    pipeline_state = pstate.PipelineState.new(m, pipeline)
    generator_params = dict(
        mlmd_handle=m,
        pipeline_state=pipeline_state,
        is_task_id_tracked_fn=task_queue.contains_task_id,
        service_job_manager=service_job_manager)
    # `ignore_node_ids` is only supported by the async generator.
    if generator_class == asptg.AsyncPipelineTaskGenerator:
      generator_params['ignore_node_ids'] = ignore_node_ids
    task_gen = generator_class(**generator_params)
    tasks = task_gen.generate()
    test_case.assertLen(
        tasks, num_tasks_generated,
        f'Expected {num_tasks_generated} task(s) to be generated.')
    executions = m.store.get_executions()
    num_total_executions = num_initial_executions + num_new_executions
    test_case.assertLen(
        executions, num_total_executions,
        f'Expected {num_total_executions} execution(s) in MLMD.')
    active_executions = [
        e for e in executions if execution_lib.is_execution_active(e)
    ]
    test_case.assertLen(
        active_executions, num_active_executions,
        f'Expected {num_active_executions} active execution(s) in MLMD.')
    if expected_exec_nodes:
      # NOTE(review): assumes tasks[i] pairs with active_executions[i] —
      # holds when exec tasks are generated in execution-creation order.
      for i, task in enumerate(tasks):
        _verify_exec_node_task(test_case, pipeline, expected_exec_nodes[i],
                               active_executions[i].id, task)
    if use_task_queue:
      for task in tasks:
        if task_lib.is_exec_node_task(task):
          task_queue.enqueue(task)
    return tasks
def run_generator_and_test(test_case,
                           mlmd_connection,
                           generator_class,
                           pipeline,
                           task_queue,
                           use_task_queue,
                           service_job_manager,
                           num_initial_executions,
                           num_tasks_generated,
                           num_new_executions,
                           num_active_executions,
                           expected_exec_nodes=None,
                           ignore_update_node_state_tasks=False,
                           fail_fast=None):
  """Runs generator.generate() and tests the effects.

  Delegates the actual generation to `run_generator` and then verifies
  execution counts (ignoring orchestrator executions) and, optionally, the
  generated exec-node tasks.

  Args:
    test_case: The `TestCase` instance used for assertions.
    mlmd_connection: A metadata connection (context manager).
    generator_class: Task generator class to instantiate.
    pipeline: The pipeline IR.
    task_queue: The task queue instance.
    use_task_queue: If `True`, generated exec-node tasks are enqueued.
    service_job_manager: Service job manager; a `DummyServiceJobManager` is
      substituted when `None`.
    num_initial_executions: Expected non-orchestrator executions beforehand.
    num_tasks_generated: Expected number of generated tasks.
    num_new_executions: Expected executions created by this generation.
    num_active_executions: Expected active executions afterwards.
    expected_exec_nodes: If set, the i-th *exec-node* task is verified
      against the i-th expected node and the i-th active execution.
    ignore_update_node_state_tasks: Forwarded to `run_generator`.
    fail_fast: Forwarded to `run_generator`.

  Returns:
    The list of generated tasks.
  """
  if service_job_manager is None:
    service_job_manager = service_jobs.DummyServiceJobManager()
  with mlmd_connection as m:
    executions = get_non_orchestrator_executions(m)
    test_case.assertLen(
        executions, num_initial_executions,
        f'Expected {num_initial_executions} execution(s) in MLMD.')
  tasks = run_generator(
      mlmd_connection,
      generator_class,
      pipeline,
      task_queue,
      use_task_queue,
      service_job_manager,
      ignore_update_node_state_tasks=ignore_update_node_state_tasks,
      fail_fast=fail_fast)
  with mlmd_connection as m:
    test_case.assertLen(
        tasks, num_tasks_generated,
        f'Expected {num_tasks_generated} task(s) to be generated.')
    executions = get_non_orchestrator_executions(m)
    num_total_executions = num_initial_executions + num_new_executions
    test_case.assertLen(
        executions, num_total_executions,
        f'Expected {num_total_executions} execution(s) in MLMD.')
    active_executions = [
        e for e in executions if execution_lib.is_execution_active(e)
    ]
    test_case.assertLen(
        active_executions, num_active_executions,
        f'Expected {num_active_executions} active execution(s) in MLMD.')
    if expected_exec_nodes:
      # Only exec-node tasks are matched against expected nodes; update
      # tasks in `tasks` are skipped by the filter.
      for i, task in enumerate(
          t for t in tasks if task_lib.is_exec_node_task(t)):
        _verify_exec_node_task(test_case, pipeline, expected_exec_nodes[i],
                               active_executions[i].id, task)
  return tasks
def create_task_scheduler(cls: Type[T], mlmd_handle: metadata.Metadata,
                          pipeline: pipeline_pb2.Pipeline,
                          task: task_lib.Task) -> TaskScheduler:
  """Creates a task scheduler for the given task.

  The task is matched as follows:
    1. The node type name of the node associated with the task is looked up
       in the registry and a scheduler is instantiated if present.
    2. Next, the executor spec url of the node (if one exists) is looked up
       in the registry and a scheduler is instantiated if present. This
       assumes deployment_config packed in the pipeline IR is of type
       `IntermediateDeploymentConfig`.
    3. Lastly, a ValueError is raised if no match can be found.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: The pipeline IR.
    task: The task that needs to be scheduled.

  Returns:
    An instance of `TaskScheduler` for the given task.

  Raises:
    NotImplementedError: Raised if not an `ExecNodeTask`.
    ValueError: If a scheduler could not be found in the registry for the
      given task.
  """
  if not task_lib.is_exec_node_task(task):
    raise NotImplementedError(
        'Can create a task scheduler only for an `ExecNodeTask`.')
  exec_node_task = typing.cast(task_lib.ExecNodeTask, task)
  try:
    # First preference: match by node type name.
    scheduler_class = cls._scheduler_class_for_node_type(exec_node_task)
  except ValueError as node_type_error:
    try:
      # Fallback: match by the node's executor spec type url.
      scheduler_class = cls._scheduler_class_for_executor_spec(
          pipeline, exec_node_task)
    except ValueError as executor_spec_error:
      raise ValueError(
          f'No task scheduler found: {node_type_error}, {executor_spec_error}'
      ) from None
  return scheduler_class(
      mlmd_handle=mlmd_handle, pipeline=pipeline, task=exec_node_task)
def test_task_generation_when_node_stopped(self, stop_transform): """Tests stopped nodes are ignored when generating tasks.""" # Simulate that ExampleGen has already completed successfully. test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1, 1) # Generate once. num_initial_executions = 1 if stop_transform: num_tasks_generated = 1 num_new_executions = 0 num_active_executions = 0 with self._mlmd_connection as m: pipeline_state = test_utils.get_or_create_pipeline_state( m, self._pipeline) with pipeline_state: with pipeline_state.node_state_update_context( task_lib.NodeUid.from_pipeline_node( self._pipeline, self._transform)) as node_state: node_state.update(pstate.NodeState.STOPPING, status_lib.Status(code=status_lib.Code.CANCELLED)) else: num_tasks_generated = 3 num_new_executions = 1 num_active_executions = 1 tasks = self._generate_and_test( True, num_initial_executions=num_initial_executions, num_tasks_generated=num_tasks_generated, num_new_executions=num_new_executions, num_active_executions=num_active_executions) self.assertLen(tasks, num_tasks_generated) if stop_transform: self.assertTrue(task_lib.is_update_node_state_task(tasks[0])) self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state) else: self.assertTrue(task_lib.is_update_node_state_task(tasks[0])) self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state) self.assertTrue(task_lib.is_update_node_state_task(tasks[1])) self.assertEqual(pstate.NodeState.RUNNING, tasks[1].state) self.assertTrue(task_lib.is_exec_node_task(tasks[2]))
def test_node_success(self):
  """Tests task generation when a node execution succeeds."""
  # Simulate a completed example-gen run in MLMD.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)
  # stats-gen is ready to execute; ignore node-state update tasks so only
  # its ExecNodeTask is returned.
  [stats_gen_task] = self._generate_and_test(
      False,
      num_initial_executions=1,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      ignore_update_node_state_tasks=True)

  # Finish stats-gen execution.
  self._finish_node_execution(False, stats_gen_task)

  # stats-gen completion should mark it COMPLETE and start schema-gen.
  [
      stats_gen_update_node_state_task, schema_gen_update_node_state_task,
      schema_gen_exec_node_task
  ] = self._generate_and_test(
      False,
      num_initial_executions=2,
      num_tasks_generated=3,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._schema_gen])
  self.assertTrue(
      task_lib.is_update_node_state_task(stats_gen_update_node_state_task))
  self.assertEqual('my_statistics_gen',
                   stats_gen_update_node_state_task.node_uid.node_id)
  self.assertEqual(pstate.NodeState.COMPLETE,
                   stats_gen_update_node_state_task.state)
  self.assertTrue(
      task_lib.is_update_node_state_task(schema_gen_update_node_state_task))
  self.assertEqual('my_schema_gen',
                   schema_gen_update_node_state_task.node_uid.node_id)
  self.assertEqual(pstate.NodeState.RUNNING,
                   schema_gen_update_node_state_task.state)
  self.assertTrue(task_lib.is_exec_node_task(schema_gen_exec_node_task))
def create_task_scheduler(cls: Type[T], mlmd_handle: metadata.Metadata,
                          pipeline: pipeline_pb2.Pipeline,
                          task: task_lib.Task) -> TaskScheduler:
  """Creates a task scheduler for the given task.

  Note that this assumes deployment_config packed in the pipeline IR is of
  type `IntermediateDeploymentConfig`. This detail may change in the future.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: The pipeline IR.
    task: The task that needs to be scheduled.

  Returns:
    An instance of `TaskScheduler` for the given task.

  Raises:
    NotImplementedError: Raised if not an `ExecNodeTask`.
    ValueError: Deployment config not present in the IR proto or if executor
      spec for the node corresponding to `task` not configured in the IR.
  """
  if not task_lib.is_exec_node_task(task):
    raise NotImplementedError(
        'Can create a task scheduler only for an `ExecNodeTask`.')
  task = typing.cast(task_lib.ExecNodeTask, task)
  # TODO(b/170383494): Decide which DeploymentConfig to use.
  # Only IntermediateDeploymentConfig is supported at the moment.
  if not pipeline.deployment_config.Is(
      pipeline_pb2.IntermediateDeploymentConfig.DESCRIPTOR):
    raise ValueError('No deployment config found in pipeline IR.')
  depl_config = pipeline_pb2.IntermediateDeploymentConfig()
  pipeline.deployment_config.Unpack(depl_config)
  node_id = task.node_uid.node_id
  if node_id not in depl_config.executor_specs:
    raise ValueError(
        'Executor spec for node id `{}` not found in pipeline IR.'.format(
            node_id))
  # The scheduler registry is keyed by the executor spec's Any type url.
  executor_spec_type_url = depl_config.executor_specs[node_id].type_url
  return cls._task_scheduler_registry[executor_spec_type_url](
      mlmd_handle=mlmd_handle, pipeline=pipeline, task=task)
def test_handling_finalize_node_task(self, task_gen):
  """Tests that a FinalizeNodeTask initiates node stop during orchestration."""
  with self._mlmd_connection as m:
    # Start a two-node pipeline.
    pipeline = _test_pipeline('pipeline1')
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline_ops.initiate_pipeline_start(m, pipeline)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    finalize_reason = status_lib.Status(
        code=status_lib.Code.ABORTED, message='foo bar')
    # Generator yields an ExecNodeTask for 'Transform' and a
    # FinalizeNodeTask for 'Trainer'.
    task_gen.return_value.generate.side_effect = [
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=pipeline_uid, node_id='Transform')),
            task_lib.FinalizeNodeTask(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=pipeline_uid, node_id='Trainer'),
                status=finalize_reason)
        ],
    ]

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())
    task_gen.return_value.generate.assert_called_once()

    # Only the ExecNodeTask for 'Transform' should have been enqueued.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)

    # Load pipeline state and verify node stop initiation.
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      self.assertEqual(
          finalize_reason,
          pipeline_state.node_stop_initiated_reason(
              task_lib.NodeUid(pipeline_uid=pipeline_uid,
                               node_id='Trainer')))
def _publish(**kwargs):
  """Mock publish side-effect that simulates a failure for 'Transform'."""
  exec_task = kwargs['task']
  assert task_lib.is_exec_node_task(exec_task)
  if exec_task.node_uid.node_id == 'Transform':
    raise ValueError('test error')
  # Defer to the mock's configured return value for all other nodes.
  return mock.DEFAULT
def test_task_generation(self, use_task_queue):
  """Tests async pipeline task generation.

  Args:
    use_task_queue: If task queue is enabled, new tasks are only generated if
      a task with the same task_id does not already exist in the queue.
      `use_task_queue=False` is useful to test the case of task generation
      when task queue is empty (for eg: due to orchestrator restart).
  """
  # Simulate that ExampleGen has already completed successfully.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)

  # Generate once.
  [update_example_gen_task, update_transform_task,
   exec_transform_task] = self._generate_and_test(
       use_task_queue,
       num_initial_executions=1,
       num_tasks_generated=3,
       num_new_executions=1,
       num_active_executions=1,
       expected_exec_nodes=[self._transform])
  self.assertTrue(task_lib.is_update_node_state_task(update_example_gen_task))
  self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
  self.assertTrue(task_lib.is_update_node_state_task(update_transform_task))
  self.assertEqual(pstate.NodeState.RUNNING, update_transform_task.state)
  self.assertTrue(task_lib.is_exec_node_task(exec_transform_task))

  # Services for both example-gen and transform must have been ensured.
  self._mock_service_job_manager.ensure_node_services.assert_has_calls([
      mock.call(mock.ANY, self._example_gen.node_info.id),
      mock.call(mock.ANY, self._transform.node_info.id)
  ])

  # No new effects if generate called again.
  tasks = self._generate_and_test(
      use_task_queue,
      num_initial_executions=2,
      num_tasks_generated=1 if use_task_queue else 3,
      num_new_executions=0,
      num_active_executions=1,
      expected_exec_nodes=[] if use_task_queue else [self._transform])
  if not use_task_queue:
    exec_transform_task = tasks[2]

  # Mark transform execution complete.
  self._finish_node_execution(use_task_queue, exec_transform_task)

  # Trainer execution task should be generated next.
  [
      update_example_gen_task, update_transform_task, update_trainer_task,
      exec_trainer_task
  ] = self._generate_and_test(
      use_task_queue,
      num_initial_executions=2,
      num_tasks_generated=4,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._trainer])
  self.assertTrue(task_lib.is_update_node_state_task(update_example_gen_task))
  self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
  self.assertTrue(task_lib.is_update_node_state_task(update_transform_task))
  self.assertEqual(pstate.NodeState.STARTED, update_transform_task.state)
  self.assertTrue(task_lib.is_update_node_state_task(update_trainer_task))
  self.assertEqual(pstate.NodeState.RUNNING, update_trainer_task.state)
  self.assertTrue(task_lib.is_exec_node_task(exec_trainer_task))

  # Mark the trainer execution complete.
  self._finish_node_execution(use_task_queue, exec_trainer_task)

  # Only UpdateNodeStateTask are generated as there are no new inputs.
  tasks = self._generate_and_test(
      use_task_queue,
      num_initial_executions=3,
      num_tasks_generated=3,
      num_new_executions=0,
      num_active_executions=0)
  for task in tasks:
    self.assertTrue(task_lib.is_update_node_state_task(task))
  # NOTE(review): this re-checks the earlier `update_example_gen_task` rather
  # than each generated task's state — possibly intended to be `task.state`.
  self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)

  # Fake another ExampleGen run.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)

  # Both transform and trainer tasks should be generated as they both find
  # new inputs.
  [
      update_example_gen_task, update_transform_task, exec_transform_task,
      update_trainer_task, exec_trainer_task
  ] = self._generate_and_test(
      use_task_queue,
      num_initial_executions=4,
      num_tasks_generated=5,
      num_new_executions=2,
      num_active_executions=2,
      expected_exec_nodes=[self._transform, self._trainer])
  self.assertTrue(task_lib.is_update_node_state_task(update_example_gen_task))
  self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
  self.assertTrue(task_lib.is_update_node_state_task(update_transform_task))
  self.assertEqual(pstate.NodeState.RUNNING, update_transform_task.state)
  self.assertTrue(task_lib.is_exec_node_task(exec_transform_task))
  self.assertTrue(task_lib.is_update_node_state_task(update_trainer_task))
  self.assertEqual(pstate.NodeState.RUNNING, update_trainer_task.state)
  self.assertTrue(task_lib.is_exec_node_task(exec_trainer_task))

  # Re-generation will produce the same tasks when task queue disabled.
  tasks = self._generate_and_test(
      use_task_queue,
      num_initial_executions=6,
      num_tasks_generated=1 if use_task_queue else 5,
      num_new_executions=0,
      num_active_executions=2,
      expected_exec_nodes=[]
      if use_task_queue else [self._transform, self._trainer])
  if not use_task_queue:
    self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
    # NOTE(review): the next three state checks also re-check
    # `update_example_gen_task.state` instead of `tasks[i].state`.
    self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
    self.assertTrue(task_lib.is_update_node_state_task(tasks[1]))
    self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
    self.assertTrue(task_lib.is_exec_node_task(tasks[2]))
    self.assertTrue(task_lib.is_update_node_state_task(tasks[3]))
    self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
    self.assertTrue(task_lib.is_exec_node_task(tasks[4]))
    exec_transform_task = tasks[2]
    exec_trainer_task = tasks[4]
  else:
    self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
    self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)

  # Mark transform execution complete.
  self._finish_node_execution(use_task_queue, exec_transform_task)
  # Mark the trainer execution complete.
  self._finish_node_execution(use_task_queue, exec_trainer_task)

  # Trainer should be triggered again due to transform producing new output.
  [
      update_example_gen_task, update_transform_task, update_trainer_task,
      exec_trainer_task
  ] = self._generate_and_test(
      use_task_queue,
      num_initial_executions=6,
      num_tasks_generated=4,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._trainer])
  self.assertTrue(task_lib.is_update_node_state_task(update_example_gen_task))
  self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
  self.assertTrue(task_lib.is_update_node_state_task(update_transform_task))
  self.assertEqual(pstate.NodeState.STARTED, update_transform_task.state)
  self.assertTrue(task_lib.is_update_node_state_task(update_trainer_task))
  self.assertEqual(pstate.NodeState.RUNNING, update_trainer_task.state)
  self.assertTrue(task_lib.is_exec_node_task(exec_trainer_task))

  # Finally, no new tasks once trainer completes.
  self._finish_node_execution(use_task_queue, exec_trainer_task)
  [update_example_gen_task, update_transform_task,
   update_trainer_task] = self._generate_and_test(
       use_task_queue,
       num_initial_executions=7,
       num_tasks_generated=3,
       num_new_executions=0,
       num_active_executions=0)
  self.assertTrue(task_lib.is_update_node_state_task(update_example_gen_task))
  self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
  self.assertTrue(task_lib.is_update_node_state_task(update_transform_task))
  self.assertEqual(pstate.NodeState.STARTED, update_transform_task.state)
  self.assertTrue(task_lib.is_update_node_state_task(update_trainer_task))
  self.assertEqual(pstate.NodeState.STARTED, update_trainer_task.state)
  if use_task_queue:
    self.assertTrue(self._task_queue.is_empty())
def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                        mock_async_task_gen,
                                        mock_sync_task_gen):
  """Tests cancellation handling for a stop-initiated async pipeline."""
  with self._mlmd_connection as m:
    pipeline1 = _test_pipeline('pipeline1')
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Evaluator'
    pipeline_ops.initiate_pipeline_start(m, pipeline1)
    pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
        m, task_lib.PipelineUid.from_pipeline(pipeline1))

    task_queue = tq.TaskQueue()

    # For the stop-initiated pipeline, "Transform" execution task is in queue,
    # "Trainer" has an active execution in MLMD but no task in queue,
    # "Evaluator" has no active execution.
    task_queue.enqueue(
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid(
                    pipeline_id='pipeline1', pipeline_run_id=None),
                node_id='Transform')))
    transform_task = task_queue.dequeue()  # simulates task being processed
    # First yielded task is a cancelled exec task for "Trainer"; remaining
    # calls yield None (no active task to cancel).
    mock_gen_task_from_active.side_effect = [
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid(
                    pipeline_id='pipeline1', pipeline_run_id=None),
                node_id='Trainer'),
            is_cancelled=True), None, None, None, None
    ]

    pipeline_ops.generate_tasks(m, task_queue)

    # There are no active pipelines so these shouldn't be called.
    mock_async_task_gen.assert_not_called()
    mock_sync_task_gen.assert_not_called()

    # Simulate finishing the "Transform" ExecNodeTask.
    task_queue.task_done(transform_task)

    # CancelNodeTask for the "Transform" ExecNodeTask should be next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_cancel_node_task(task))
    self.assertEqual('Transform', task.node_uid.node_id)

    # ExecNodeTask for "Trainer" is next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual('Trainer', task.node_uid.node_id)
    self.assertTrue(task_queue.is_empty())

    # Cancelled-task generation should have been attempted for the "Trainer"
    # and "Evaluator" nodes (pipeline nodes 1 and 2).
    mock_gen_task_from_active.assert_has_calls([
        mock.call(
            m,
            pipeline1,
            pipeline1.nodes[1].pipeline_node,
            mock.ANY,
            is_cancelled=True),
        mock.call(
            m,
            pipeline1,
            pipeline1.nodes[2].pipeline_node,
            mock.ANY,
            is_cancelled=True)
    ])
    self.assertEqual(2, mock_gen_task_from_active.call_count)

    # Pipeline execution should continue to be active since active node
    # executions were found in the last call to `generate_tasks`.
    [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
    self.assertTrue(execution_lib.is_execution_active(execution))

    # Call `generate_tasks` again; this time there are no more active node
    # executions so the pipeline should be marked as cancelled.
    pipeline_ops.generate_tasks(m, task_queue)
    self.assertTrue(task_queue.is_empty())
    [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
    self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                     execution.last_known_state)
def setUp(self):
  """Creates an MLMD-backed async pipeline fixture with one active task."""
  super().setUp()
  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())

  # Makes sure multiple connections within a test always connect to the same
  # MLMD instance.
  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  self._metadata_path = metadata_path
  connection_config = metadata.sqlite_metadata_connection_config(
      metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)

  # Sets up the pipeline.
  pipeline = test_async_pipeline.create_pipeline()

  # Extracts components.
  self._example_gen = pipeline.nodes[0].pipeline_node
  self._transform = pipeline.nodes[1].pipeline_node
  self._trainer = pipeline.nodes[2].pipeline_node

  # Pack deployment config for testing. Both trainer and transform share the
  # same (fake) executor spec; its Any type url is kept for registry lookups.
  deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
  executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
      class_path='fake.ClassPath')
  deployment_config.executor_specs[self._trainer.node_info.id].Pack(
      executor_spec)
  deployment_config.executor_specs[self._transform.node_info.id].Pack(
      executor_spec)
  self._type_url = deployment_config.executor_specs[
      self._trainer.node_info.id].type_url
  pipeline.deployment_config.Pack(deployment_config)
  self._pipeline = pipeline
  self._pipeline_info = pipeline.pipeline_info
  self._pipeline_runtime_spec = pipeline.runtime_spec
  self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
      pipeline_root)

  ts.TaskSchedulerRegistry.clear()
  self._task_queue = tq.TaskQueue()

  # Run fake example-gen to prepare downstreams component triggers.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)

  # Task generator should produce two tasks for transform. The first one is
  # UpdateNodeStateTask and the second one is ExecNodeTask.
  with self._mlmd_connection as m:
    pipeline_state = pstate.PipelineState.new(m, self._pipeline)
    tasks = asptg.AsyncPipelineTaskGenerator(
        m, self._task_queue.contains_task_id,
        service_jobs.DummyServiceJobManager()).generate(pipeline_state)
    self.assertLen(tasks, 2)
    self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
    self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
    self.assertEqual('my_transform', tasks[0].node_uid.node_id)
    self.assertTrue(task_lib.is_exec_node_task(tasks[1]))
    self.assertEqual('my_transform', tasks[1].node_uid.node_id)
    self.assertTrue(os.path.exists(tasks[1].stateful_working_dir))
    self.assertTrue(os.path.exists(tasks[1].tmp_dir))
    # Keep the ExecNodeTask and its output artifact uri for the tests.
    self._task = tasks[1]
    self._output_artifact_uri = self._task.output_artifacts[
        'transform_graph'][0].uri
    self.assertTrue(os.path.exists(self._output_artifact_uri))
    self._task_queue.enqueue(self._task)

  # There should be 1 active execution in MLMD.
  with self._mlmd_connection as m:
    executions = m.store.get_executions()
    active_executions = [
        e for e in executions
        if e.last_known_state == metadata_store_pb2.Execution.RUNNING
    ]
    self.assertLen(active_executions, 1)

    # Active execution id.
    self._execution_id = active_executions[0].id
def test_generate_tasks_async_active_pipelines(self, mock_async_task_gen,
                                               mock_sync_task_gen):
  """Tests task generation across multiple active async pipelines."""
  with self._mlmd_connection as m:
    # One active pipeline.
    pipeline1 = _test_pipeline('pipeline1')
    pipeline_ops.initiate_pipeline_start(m, pipeline1)

    # Another active pipeline (with previously completed execution).
    pipeline2 = _test_pipeline('pipeline2')
    execution2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)
    execution2.last_known_state = metadata_store_pb2.Execution.COMPLETE
    m.store.put_executions([execution2])
    execution2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)

    # Inactive pipelines should be ignored.
    pipeline3 = _test_pipeline('pipeline3')
    execution3 = pipeline_ops.initiate_pipeline_start(m, pipeline3)
    execution3.last_known_state = metadata_store_pb2.Execution.COMPLETE
    m.store.put_executions([execution3])

    # For active pipelines pipeline1 and pipeline2, there are a couple of
    # active executions.
    def _exec_node_tasks():
      for pipeline_id in ('pipeline1', 'pipeline2'):
        yield [
            test_utils.create_exec_node_task(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(
                        pipeline_id=pipeline_id, pipeline_run_id=None),
                    node_id='Transform')),
            test_utils.create_exec_node_task(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(
                        pipeline_id=pipeline_id, pipeline_run_id=None),
                    node_id='Trainer'))
        ]

    mock_async_task_gen.return_value.generate.side_effect = _exec_node_tasks()

    task_queue = tq.TaskQueue()
    pipeline_ops.generate_tasks(m, task_queue)

    # Exactly one generate() per active pipeline; sync generator unused.
    self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
    mock_sync_task_gen.assert_not_called()

    # Verify that tasks are enqueued in the expected order.
    for node_id in ('Transform', 'Trainer'):
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(node_id, task.node_uid.node_id)
      self.assertEqual('pipeline1', task.node_uid.pipeline_uid.pipeline_id)
    for node_id in ('Transform', 'Trainer'):
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(node_id, task.node_uid.node_id)
      self.assertEqual('pipeline2', task.node_uid.pipeline_uid.pipeline_id)
    self.assertTrue(task_queue.is_empty())
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates an active pipeline.

  Marks the pipeline's MLMD execution RUNNING if it is still NEW, builds the
  task generator matching the pipeline's execution mode (handling
  stop-initiated node cancellation for ASYNC pipelines), then dispatches the
  generated tasks: exec-node tasks are enqueued, finalize-node tasks initiate
  node stop, and a finalize-pipeline task initiates pipeline stop.

  Args:
    mlmd_handle: Handle to MLMD, used to update the pipeline execution state.
    task_queue: Queue onto which generated `ExecNodeTask`s are enqueued.
    service_job_manager: Used to identify and stop service-node jobs.
    pipeline_state: State of the pipeline being orchestrated; its execution
      must be in state NEW or RUNNING.

  Raises:
    status_lib.StatusNotOkError: If the pipeline's execution mode is neither
      SYNC nor ASYNC.
  """
  pipeline = pipeline_state.pipeline
  execution = pipeline_state.execution
  # Only pipelines whose execution is NEW or RUNNING are considered active.
  assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                        metadata_store_pb2.Execution.RUNNING)
  # Transition a freshly-started (NEW) pipeline execution to RUNNING. The
  # execution proto is deep-copied so the cached state is not mutated in place.
  if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
    updated_execution = copy.deepcopy(execution)
    updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
    mlmd_handle.store.put_executions([updated_execution])

  # Initialize task generator for the pipeline.
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle, pipeline_state, task_queue.contains_task_id,
        service_job_manager)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    # Create cancellation tasks for stop-initiated nodes if necessary.
    stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
    for node in stop_initiated_nodes:
      # Pure service nodes are stopped via the service job manager; otherwise
      # try to enqueue a cancellation task for the node's active execution.
      # Mixed service nodes additionally get their services stopped only when
      # no cancellation task could be enqueued.
      if service_job_manager.is_pure_service_node(pipeline_state,
                                                  node.node_info.id):
        service_job_manager.stop_node_services(pipeline_state,
                                               node.node_info.id)
      elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                            task_queue):
        pass
      elif service_job_manager.is_mixed_service_node(pipeline_state,
                                                     node.node_info.id):
        service_job_manager.stop_node_services(pipeline_state,
                                               node.node_info.id)
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, pipeline_state, task_queue.contains_task_id,
        service_job_manager,
        set(n.node_info.id for n in stop_initiated_nodes))
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'))
  tasks = generator.generate()

  # Dispatch the generated tasks while holding the pipeline state lock.
  with pipeline_state:
    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task = typing.cast(task_lib.ExecNodeTask, task)
        task_queue.enqueue(task)
      elif task_lib.is_finalize_node_task(task):
        # Finalize-node tasks are only produced for ASYNC pipelines.
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
        task = typing.cast(task_lib.FinalizeNodeTask, task)
        pipeline_state.initiate_node_stop(task.node_uid, task.status)
      else:
        # A finalize-pipeline task is only produced for SYNC pipelines, and
        # when produced it is the sole generated task.
        assert task_lib.is_finalize_pipeline_task(task)
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
        assert len(tasks) == 1
        task = typing.cast(task_lib.FinalizePipelineTask, task)
        if task.status.code == status_lib.Code.OK:
          logging.info('Pipeline run successful; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        else:
          logging.info('Pipeline run failed; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        pipeline_state.initiate_stop(task.status)
def test_conditional_execution(self, evaluate):
  """Tests conditionals in the pipeline.

  Args:
    evaluate: Whether to run the conditional evaluator.
  """
  # Sanity-check the expected terminal nodes of the pipeline graph.
  layers = sptg._topsorted_layers(self._pipeline)
  expected_terminal_ids = {
      self._example_validator.node_info.id,
      self._chore_b.node_info.id,
      self._evaluator.node_info.id,
  }
  self.assertEqual(expected_terminal_ids, sptg._terminal_node_ids(layers))

  # Start executing the pipeline:
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)
  self._run_next(False, expect_nodes=[self._stats_gen])
  self._run_next(False, expect_nodes=[self._schema_gen])
  self._run_next(
      False, expect_nodes=[self._example_validator, self._transform])

  # Evaluator is run conditionally based on whether the Model artifact
  # produced by the trainer has a custom property evaluate=1.
  trainer_artifact_properties = {'evaluate': 1} if evaluate else None
  self._run_next(
      False,
      expect_nodes=[self._trainer],
      artifact_custom_properties=trainer_artifact_properties)

  tasks = self._generate(False)

  # Exactly one node-state update task is expected for the evaluator; its
  # state reflects whether the conditional was satisfied.
  (evaluator_update_node_state_task,) = [
      t for t in tasks
      if task_lib.is_update_node_state_task(t) and
      t.node_uid.node_id == 'my_evaluator'
  ]
  expected_evaluator_state = (
      pstate.NodeState.RUNNING if evaluate else pstate.NodeState.SKIPPED)
  self.assertEqual(expected_evaluator_state,
                   evaluator_update_node_state_task.state)

  # chore_a always runs; my_evaluator runs only when the conditional holds.
  exec_node_tasks = [t for t in tasks if task_lib.is_exec_node_task(t)]
  expected_exec_node_ids = ['chore_a']
  if evaluate:
    expected_exec_node_ids.append('my_evaluator')
  self.assertEqual(expected_exec_node_ids,
                   [t.node_uid.node_id for t in exec_node_tasks])
  for exec_task in exec_node_tasks:
    self._finish_node_execution(False, exec_task)

  self._run_next(False, expect_nodes=[self._chore_b])

  # All nodes executed, finalization task should be produced.
  [finalize_task] = self._generate(False, True)
  self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
def __call__(self) -> List[task_lib.Task]:
  """Generates tasks for the sync pipeline.

  Walks the topologically-sorted layers of the pipeline graph, skipping nodes
  that already succeeded or failed and nodes whose upstream nodes have not all
  succeeded, and collects tasks from `_generate_tasks_for_node` for the rest.
  In fail-fast mode, the first node failure aborts the whole pipeline.

  Returns:
    All update-node-state tasks, followed by exactly one of:
      * a finalize-pipeline task, when the pipeline is aborted (fail-fast
        failure, or no runnable nodes remain) or when all terminal nodes have
        succeeded; or
      * the collected exec-node tasks, otherwise.
  """
  layers = _topsorted_layers(self._pipeline)
  terminal_node_ids = _terminal_node_ids(layers)
  exec_node_tasks = []
  update_node_state_tasks = []
  successful_node_ids = set()
  failed_nodes_dict: Dict[str, status_lib.Status] = {}
  # Set to a non-None task when the pipeline should be finalized (aborted).
  finalize_pipeline_task = None
  for layer_nodes in layers:
    for node in layer_nodes:
      node_id = node.node_info.id
      node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)
      node_state = self._node_states_dict[node_uid]
      if node_state.is_success():
        successful_node_ids.add(node_id)
        continue
      if node_state.is_failure():
        failed_nodes_dict[node_id] = node_state.status
        continue
      # A node is only considered once all of its upstream nodes succeeded.
      if not self._upstream_nodes_successful(node, successful_node_ids):
        continue
      tasks = self._generate_tasks_for_node(node)
      for task in tasks:
        if task_lib.is_update_node_state_task(task):
          task = typing.cast(task_lib.UpdateNodeStateTask, task)
          # Track newly succeeded/failed nodes from the generated state
          # updates; in fail-fast mode a failure aborts the pipeline.
          if pstate.is_node_state_success(task.state):
            successful_node_ids.add(node_id)
          elif pstate.is_node_state_failure(task.state):
            failed_nodes_dict[node_id] = task.status
            if self._fail_fast:
              finalize_pipeline_task = self._abort_task(task.status.message)
          update_node_state_tasks.append(task)
        elif task_lib.is_exec_node_task(task):
          exec_node_tasks.append(task)
      # Break out of both loops as soon as an abort has been decided.
      if finalize_pipeline_task:
        break
    if finalize_pipeline_task:
      break

  if not self._fail_fast and failed_nodes_dict:
    assert not finalize_pipeline_task
    node_by_id = _node_by_id(self._pipeline)
    # Collect nodes that cannot be run because they have a failed ancestor.
    unrunnable_node_ids = set()
    for node_id in failed_nodes_dict:
      unrunnable_node_ids |= _descendants(node_by_id, node_id)
    # Nodes that are still runnable have neither succeeded nor failed, and
    # don't have a failed ancestor.
    runnable_node_ids = node_by_id.keys() - (
        unrunnable_node_ids | successful_node_ids | failed_nodes_dict.keys())
    # If there are no runnable nodes, we can abort the pipeline.
    if not runnable_node_ids:
      finalize_pipeline_task = self._abort_task(
          f'Cannot make progress due to node failures: {failed_nodes_dict}')

  # NOTE: `result` aliases `update_node_state_tasks`; the appends/extends
  # below intentionally build on top of the state-update tasks.
  result = update_node_state_tasks
  if finalize_pipeline_task:
    result.append(finalize_pipeline_task)
  elif terminal_node_ids <= successful_node_ids:
    # If all terminal nodes are successful, the pipeline can be finalized.
    result.append(
        task_lib.FinalizePipelineTask(
            pipeline_uid=self._pipeline_uid,
            status=status_lib.Status(code=status_lib.Code.OK)))
  else:
    result.extend(exec_node_tasks)
  return result
def test_orchestrate_active_pipelines(self, mock_async_task_gen,
                                      mock_sync_task_gen):
  """Tests that orchestration covers all active sync and async pipelines."""
  with self._mlmd_connection as m:
    # Sync and async active pipelines.
    async_pipelines = [
        _test_pipeline('pipeline1'),
        _test_pipeline('pipeline2'),
    ]
    sync_pipelines = [
        _test_pipeline('pipeline3', pipeline_pb2.Pipeline.SYNC),
        _test_pipeline('pipeline4', pipeline_pb2.Pipeline.SYNC),
    ]
    for pipeline in async_pipelines + sync_pipelines:
      pipeline_ops.initiate_pipeline_start(m, pipeline)

    def _one_task_per_pipeline(pipelines, node_ids):
      # One single-task generation result per pipeline.
      return [[
          test_utils.create_exec_node_task(
              task_lib.NodeUid(
                  pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                  node_id=node_id))
      ] for pipeline, node_id in zip(pipelines, node_ids)]

    # Active executions for active async pipelines.
    mock_async_task_gen.return_value.generate.side_effect = (
        _one_task_per_pipeline(async_pipelines, ['Transform', 'Trainer']))

    # Active executions for active sync pipelines.
    mock_sync_task_gen.return_value.generate.side_effect = (
        _one_task_per_pipeline(sync_pipelines, ['Trainer', 'Validator']))

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())

    self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
    self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count)

    # Verify that tasks are enqueued in the expected order.
    expected_node_uids = [
        ('pipeline1', 'Transform'),
        ('pipeline2', 'Trainer'),
        ('pipeline3', 'Trainer'),
        ('pipeline4', 'Validator'),
    ]
    for pipeline_id, node_id in expected_node_uids:
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid(pipeline_id, node_id), task.node_uid)
    self.assertTrue(task_queue.is_empty())
def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
                                  mock_async_task_gen, mock_sync_task_gen):
  """Tests orchestration of a stop-initiated pipeline.

  Verifies that cancellation is propagated to service jobs, queued tasks and
  active MLMD executions, and that the pipeline execution is marked CANCELED
  once no active node executions remain.
  """
  with self._mlmd_connection as m:
    pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

    # ExampleGen is a pure service node; Transform is a mixed service node.
    mock_service_job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)
    mock_service_job_manager.is_pure_service_node.side_effect = (
        lambda _, node_id: node_id == 'ExampleGen')
    mock_service_job_manager.is_mixed_service_node.side_effect = (
        lambda _, node_id: node_id == 'Transform')

    # Start the pipeline, then immediately initiate a stop (CANCELLED).
    pipeline_ops.initiate_pipeline_start(m, pipeline)
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      pipeline_state.initiate_stop(
          status_lib.Status(code=status_lib.Code.CANCELLED))
      pipeline_execution = pipeline_state.execution

    task_queue = tq.TaskQueue()

    # For the stop-initiated pipeline, "Transform" execution task is in queue,
    # "Trainer" has an active execution in MLMD but no task in queue,
    # "Evaluator" has no active execution.
    task_queue.enqueue(
        test_utils.create_exec_node_task(
            task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                node_id='Transform')))
    transform_task = task_queue.dequeue()  # simulates task being processed
    mock_gen_task_from_active.side_effect = [
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                node_id='Trainer'),
            is_cancelled=True), None, None, None, None
    ]

    pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

    # There are no active pipelines so these shouldn't be called.
    mock_async_task_gen.assert_not_called()
    mock_sync_task_gen.assert_not_called()

    # stop_node_services should be called for ExampleGen which is a pure
    # service node.
    mock_service_job_manager.stop_node_services.assert_called_once_with(
        mock.ANY, 'ExampleGen')
    mock_service_job_manager.reset_mock()

    task_queue.task_done(transform_task)  # Pop out transform task.

    # CancelNodeTask for the "Transform" ExecNodeTask should be next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_cancel_node_task(task))
    self.assertEqual('Transform', task.node_uid.node_id)

    # ExecNodeTask (with is_cancelled=True) for "Trainer" is next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual('Trainer', task.node_uid.node_id)
    self.assertTrue(task.is_cancelled)

    self.assertTrue(task_queue.is_empty())

    # Cancelled-task generation should have been attempted for the two
    # non-service nodes (Trainer at index 2 and Evaluator at index 3).
    mock_gen_task_from_active.assert_has_calls([
        mock.call(
            m,
            pipeline_state.pipeline,
            pipeline.nodes[2].pipeline_node,
            mock.ANY,
            is_cancelled=True),
        mock.call(
            m,
            pipeline_state.pipeline,
            pipeline.nodes[3].pipeline_node,
            mock.ANY,
            is_cancelled=True)
    ])
    self.assertEqual(2, mock_gen_task_from_active.call_count)

    # Pipeline execution should continue to be active since active node
    # executions were found in the last call to `orchestrate`.
    [execution] = m.store.get_executions_by_id([pipeline_execution.id])
    self.assertTrue(execution_lib.is_execution_active(execution))

    # Call `orchestrate` again; this time there are no more active node
    # executions so the pipeline should be marked as cancelled.
    pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
    self.assertTrue(task_queue.is_empty())
    [execution] = m.store.get_executions_by_id([pipeline_execution.id])
    self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                     execution.last_known_state)

    # stop_node_services should be called on both ExampleGen and Transform
    # which are service nodes.
    mock_service_job_manager.stop_node_services.assert_has_calls(
        [mock.call(mock.ANY, 'ExampleGen'),
         mock.call(mock.ANY, 'Transform')],
        any_order=True)
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates an active pipeline.

  Ensures the pipeline execution is RUNNING, aborts SYNC pipelines that have
  exceeded their orchestration deadline, stops nodes in state STOPPING
  (transitioning them to STOPPED), generates tasks for the pipeline, applies
  node-state updates, and enqueues exec-node tasks or initiates pipeline stop
  on a finalize-pipeline task.

  Args:
    mlmd_handle: Handle to MLMD, passed through to the task generators.
    task_queue: Queue onto which generated `ExecNodeTask`s are enqueued.
    service_job_manager: Used to identify and stop service-node jobs.
    pipeline_state: State of the pipeline being orchestrated; must be active.

  Raises:
    status_lib.StatusNotOkError: If the pipeline's execution mode is neither
      SYNC nor ASYNC.
  """
  pipeline = pipeline_state.pipeline
  with pipeline_state:
    assert pipeline_state.is_active()
    # Transition the pipeline execution to RUNNING if it isn't already.
    if pipeline_state.get_pipeline_execution_state() != (
        metadata_store_pb2.Execution.RUNNING):
      pipeline_state.set_pipeline_execution_state(
          metadata_store_pb2.Execution.RUNNING)
    orchestration_options = pipeline_state.get_orchestration_options()
    logging.info('Orchestration options: %s', orchestration_options)
    deadline_secs = orchestration_options.deadline_secs
    # A positive deadline only applies to SYNC pipelines; exceeding it aborts
    # the pipeline with DEADLINE_EXCEEDED.
    if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC and
        deadline_secs > 0 and
        time.time() - pipeline_state.pipeline_creation_time_secs_since_epoch()
        > deadline_secs):
      logging.error(
          'Aborting pipeline due to exceeding deadline (%s secs); '
          'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
      pipeline_state.initiate_stop(
          status_lib.Status(
              code=status_lib.Code.DEADLINE_EXCEEDED,
              message=('Pipeline aborted due to exceeding deadline '
                       f'({deadline_secs} secs)')))
      return

  def _filter_by_state(node_infos: List[_NodeInfo],
                       state_str: str) -> List[_NodeInfo]:
    # Returns the subset of node_infos whose node state matches state_str.
    return [n for n in node_infos if n.state.state == state_str]

  node_infos = _get_node_infos(pipeline_state)
  stopping_node_infos = _filter_by_state(node_infos, pstate.NodeState.STOPPING)

  # Tracks nodes stopped in the current iteration.
  stopped_node_infos: List[_NodeInfo] = []

  # Create cancellation tasks for nodes in state STOPPING.
  for node_info in stopping_node_infos:
    # Pure service nodes are stopped via the service job manager; a node with
    # an active execution gets a cancellation task enqueued instead. Mixed
    # service nodes with no cancellable execution also get their services
    # stopped. Nodes with nothing to cancel are considered stopped right away.
    if service_job_manager.is_pure_service_node(pipeline_state,
                                                node_info.node.node_info.id):
      if service_job_manager.stop_node_services(pipeline_state,
                                                node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                          node_info.node, task_queue):
      pass
    elif service_job_manager.is_mixed_service_node(
        pipeline_state, node_info.node.node_info.id):
      if service_job_manager.stop_node_services(pipeline_state,
                                                node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    else:
      stopped_node_infos.append(node_info)

  # Change the state of stopped nodes from STOPPING to STOPPED.
  if stopped_node_infos:
    with pipeline_state:
      for node_info in stopped_node_infos:
        node_uid = task_lib.NodeUid.from_pipeline_node(pipeline,
                                                       node_info.node)
        with pipeline_state.node_state_update_context(
            node_uid) as node_state:
          node_state.update(pstate.NodeState.STOPPED, node_state.status)

  # Initialize task generator for the pipeline.
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle,
        task_queue.contains_task_id,
        service_job_manager,
        fail_fast=orchestration_options.fail_fast)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, task_queue.contains_task_id, service_job_manager)
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'))

  tasks = generator.generate(pipeline_state)

  with pipeline_state:
    # Handle all the UpdateNodeStateTasks by updating node states.
    for task in tasks:
      if task_lib.is_update_node_state_task(task):
        task = typing.cast(task_lib.UpdateNodeStateTask, task)
        with pipeline_state.node_state_update_context(
            task.node_uid) as node_state:
          node_state.update(task.state, task.status)

    # Only non-state-update tasks remain to be dispatched below.
    tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

    # If there are still nodes in state STARTING, change them to STARTED.
    for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
      node_uid = task_lib.NodeUid.from_pipeline_node(pipeline_state.pipeline,
                                                     node)
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        if node_state.state == pstate.NodeState.STARTING:
          node_state.update(pstate.NodeState.STARTED)

    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task = typing.cast(task_lib.ExecNodeTask, task)
        task_queue.enqueue(task)
      else:
        # A finalize-pipeline task is only produced for SYNC pipelines, and
        # when produced it is the sole remaining task.
        assert task_lib.is_finalize_pipeline_task(task)
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
        assert len(tasks) == 1
        task = typing.cast(task_lib.FinalizePipelineTask, task)
        if task.status.code == status_lib.Code.OK:
          logging.info('Pipeline run successful; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        else:
          logging.info('Pipeline run failed; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        pipeline_state.initiate_stop(task.status)