def setUp(self):
  super().setUp()
  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())

  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  self._metadata_path = metadata_path
  connection_config = metadata.sqlite_metadata_connection_config(
      metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)

  pipeline = self._make_pipeline(pipeline_root, str(uuid.uuid4()))
  self._pipeline = pipeline
  self._importer_node = self._pipeline.nodes[0].pipeline_node

  self._task_queue = tq.TaskQueue()
  [importer_task] = test_utils.run_generator_and_test(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=self._pipeline,
      task_queue=self._task_queue,
      use_task_queue=True,
      service_job_manager=None,
      num_initial_executions=0,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._importer_node])
  self._importer_task = importer_task

def setUp(self):
  super(AsyncPipelineTaskGeneratorTest, self).setUp()
  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())
  self._pipeline_root = pipeline_root

  # Makes sure multiple connections within a test always connect to the same
  # MLMD instance.
  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  self._metadata_path = metadata_path
  connection_config = metadata.sqlite_metadata_connection_config(
      metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)

  # Sets up the pipeline.
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(
          os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
      pipeline)
  self._pipeline = pipeline
  self._pipeline_info = pipeline.pipeline_info
  self._pipeline_runtime_spec = pipeline.runtime_spec
  self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
      pipeline_root)

  # Extracts components.
  self._example_gen = pipeline.nodes[0].pipeline_node
  self._transform = pipeline.nodes[1].pipeline_node
  self._trainer = pipeline.nodes[2].pipeline_node

  self._task_queue = tq.TaskQueue()
  self._ignore_node_ids = set([self._example_gen.node_info.id])

def test_task_queue_operations(self):
  t1 = _test_task(node_id='trainer', pipeline_id='my_pipeline')
  t2 = _test_task(
      node_id='transform', pipeline_id='my_pipeline', pipeline_run_id='run_0')
  tq = task_queue.TaskQueue()

  # Enqueueing new tasks is successful.
  self.assertTrue(tq.enqueue(t1))
  self.assertTrue(tq.enqueue(t2))

  # Re-enqueueing the same tasks fails.
  self.assertFalse(tq.enqueue(t1))
  self.assertFalse(tq.enqueue(t2))

  # Dequeue succeeds and returns `None` when queue is empty.
  self.assertEqual(t1, tq.dequeue())
  self.assertEqual(t2, tq.dequeue())
  self.assertIsNone(tq.dequeue())
  self.assertIsNone(tq.dequeue(0.1))

  # Re-enqueueing the same tasks fails as `task_done` has not been called.
  self.assertFalse(tq.enqueue(t1))
  self.assertFalse(tq.enqueue(t2))

  tq.task_done(t1)
  tq.task_done(t2)

  # Re-enqueueing is allowed after `task_done` has been called.
  self.assertTrue(tq.enqueue(t1))
  self.assertTrue(tq.enqueue(t2))

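# The test above exercises the enqueue/dequeue/task_done contract of the task
# queue. The class below is a hypothetical, minimal sketch of a queue with the
# same observable semantics, written only to illustrate that contract; it is
# NOT the actual tfx `task_queue.TaskQueue` implementation, and the assumption
# that tasks expose a hashable `task_id` attribute is ours.
import queue
import threading


class SimpleTaskQueue:
  """Sketch: de-duplicating FIFO queue with explicit task completion."""

  def __init__(self):
    self._lock = threading.Lock()
    self._queue = queue.Queue()
    self._pending_ids = set()   # Enqueued but not yet dequeued.
    self._in_progress = {}      # task_id -> task; dequeued but not done.

  def enqueue(self, task):
    """Returns False if a task with the same id is already tracked."""
    with self._lock:
      task_id = task.task_id  # Assumed hashable task identifier.
      if task_id in self._pending_ids or task_id in self._in_progress:
        return False
      self._pending_ids.add(task_id)
      self._queue.put(task)
      return True

  def dequeue(self, max_wait_secs=None):
    """Returns the next task, or None if the queue is empty."""
    try:
      task = self._queue.get(
          block=max_wait_secs is not None, timeout=max_wait_secs)
    except queue.Empty:
      return None
    with self._lock:
      self._pending_ids.discard(task.task_id)
      self._in_progress[task.task_id] = task
    return task

  def task_done(self, task):
    """Marks a dequeued task complete so it may be enqueued again."""
    with self._lock:
      task_id = task.task_id
      if task_id not in self._in_progress:
        if task_id in self._pending_ids:
          raise RuntimeError('Must call `dequeue` before `task_done`.')
        raise RuntimeError('Task not present in the queue.')
      del self._in_progress[task_id]

  def contains_task_id(self, task_id):
    with self._lock:
      return task_id in self._pending_ids or task_id in self._in_progress
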
def setUp(self):
  super(AsyncPipelineTaskGeneratorTest, self).setUp()
  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())
  self._pipeline_root = pipeline_root

  # Makes sure multiple connections within a test always connect to the same
  # MLMD instance.
  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  self._metadata_path = metadata_path
  connection_config = metadata.sqlite_metadata_connection_config(
      metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)

  # Sets up the pipeline.
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(
          os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
      pipeline)
  self._pipeline = pipeline
  self._pipeline_info = pipeline.pipeline_info
  self._pipeline_runtime_spec = pipeline.runtime_spec
  self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
      pipeline_root)

  # Extracts components.
  self._example_gen = pipeline.nodes[0].pipeline_node
  self._transform = pipeline.nodes[1].pipeline_node
  self._trainer = pipeline.nodes[2].pipeline_node

  self._task_queue = tq.TaskQueue()

  self._mock_service_job_manager = mock.create_autospec(
      service_jobs.ServiceJobManager, instance=True)

  def _is_pure_service_node(unused_pipeline_state, node_id):
    return node_id == self._example_gen.node_info.id

  def _is_mixed_service_node(unused_pipeline_state, node_id):
    return node_id == self._transform.node_info.id

  self._mock_service_job_manager.is_pure_service_node.side_effect = (
      _is_pure_service_node)
  self._mock_service_job_manager.is_mixed_service_node.side_effect = (
      _is_mixed_service_node)

  def _default_ensure_node_services(unused_pipeline_state, node_id):
    self.assertIn(
        node_id,
        (self._example_gen.node_info.id, self._transform.node_info.id))
    return service_jobs.ServiceStatus.RUNNING

  self._mock_service_job_manager.ensure_node_services.side_effect = (
      _default_ensure_node_services)

def test_queue_multiplexing(self, mock_publish):
  # Create a pipeline IR containing deployment config for testing.
  deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
  executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
      class_path='trainer.TrainerExecutor')
  deployment_config.executor_specs['Trainer'].Pack(executor_spec)
  deployment_config.executor_specs['Transform'].Pack(executor_spec)
  deployment_config.executor_specs['Evaluator'].Pack(executor_spec)
  pipeline = pipeline_pb2.Pipeline()
  pipeline.deployment_config.Pack(deployment_config)

  collector = _Collector()

  # Register a fake task scheduler.
  ts.TaskSchedulerRegistry.register(
      deployment_config.executor_specs['Trainer'].type_url,
      functools.partial(
          _FakeTaskScheduler,
          block_nodes={'Trainer', 'Transform'},
          collector=collector))

  task_queue = tq.TaskQueue()

  # Enqueue some tasks.
  trainer_exec_task = _test_exec_task('Trainer', 'test-pipeline')
  task_queue.enqueue(trainer_exec_task)
  task_queue.enqueue(_test_cancel_task('Trainer', 'test-pipeline'))

  with tm.TaskManager(
      mock.Mock(),
      pipeline,
      task_queue,
      max_active_task_schedulers=1000,
      max_dequeue_wait_secs=0.1,
      process_all_queued_tasks_before_exit=True):
    # Enqueue more tasks after task manager starts.
    transform_exec_task = _test_exec_task('Transform', 'test-pipeline')
    task_queue.enqueue(transform_exec_task)
    evaluator_exec_task = _test_exec_task('Evaluator', 'test-pipeline')
    task_queue.enqueue(evaluator_exec_task)
    task_queue.enqueue(_test_cancel_task('Transform', 'test-pipeline'))

  # Ensure that all exec and cancellation tasks were processed correctly.
  self.assertCountEqual(
      [trainer_exec_task, transform_exec_task, evaluator_exec_task],
      collector.scheduled_tasks)
  self.assertCountEqual([trainer_exec_task, transform_exec_task],
                        collector.cancelled_tasks)
  mock_publish.assert_has_calls([
      mock.call(mock.ANY, pipeline, trainer_exec_task, mock.ANY),
      mock.call(mock.ANY, pipeline, transform_exec_task, mock.ANY),
      mock.call(mock.ANY, pipeline, evaluator_exec_task, mock.ANY),
  ], any_order=True)

def test_task_handling(self, mock_publish):
  collector = _Collector()

  # Register a fake task scheduler.
  ts.TaskSchedulerRegistry.register(
      self._type_url,
      functools.partial(
          _FakeTaskScheduler,
          block_nodes={'Trainer', 'Transform'},
          collector=collector))

  task_queue = tq.TaskQueue()

  # Enqueue some tasks.
  trainer_exec_task = _test_exec_node_task(
      'Trainer', 'test-pipeline', pipeline=self._pipeline)
  task_queue.enqueue(trainer_exec_task)
  task_queue.enqueue(_test_cancel_node_task('Trainer', 'test-pipeline'))

  with self._task_manager(task_queue) as task_manager:
    # Enqueue more tasks after task manager starts.
    transform_exec_task = _test_exec_node_task(
        'Transform', 'test-pipeline', pipeline=self._pipeline)
    task_queue.enqueue(transform_exec_task)
    evaluator_exec_task = _test_exec_node_task(
        'Evaluator', 'test-pipeline', pipeline=self._pipeline)
    task_queue.enqueue(evaluator_exec_task)
    task_queue.enqueue(_test_cancel_node_task('Transform', 'test-pipeline'))

  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # Ensure that all exec and cancellation tasks were processed correctly.
  self.assertCountEqual(
      [trainer_exec_task, transform_exec_task, evaluator_exec_task],
      collector.scheduled_tasks)
  self.assertCountEqual([trainer_exec_task, transform_exec_task],
                        collector.cancelled_tasks)
  mock_publish.assert_has_calls([
      mock.call(
          mlmd_handle=mock.ANY, task=trainer_exec_task, result=mock.ANY),
      mock.call(
          mlmd_handle=mock.ANY, task=transform_exec_task, result=mock.ANY),
      mock.call(
          mlmd_handle=mock.ANY, task=evaluator_exec_task, result=mock.ANY),
  ], any_order=True)

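# The task-manager tests above rely on `_Collector` and `_FakeTaskScheduler`
# helpers that are not shown in this excerpt. Below is a hypothetical sketch of
# what such helpers could look like, inferred from how the tests use them: the
# scheduler records scheduled tasks, blocks on nodes listed in `block_nodes`
# until cancelled, and returns OK or CANCELLED with the message
# '_FakeTaskScheduler result'. The constructor keyword handling and import
# paths are assumptions; they are not taken from the source.
import threading

from tfx.orchestration.experimental.core import task_scheduler as ts
from tfx.utils import status as status_lib


class _Collector:
  """Records which tasks were scheduled and which were cancelled."""

  def __init__(self):
    self._lock = threading.Lock()
    self.scheduled_tasks = []
    self.cancelled_tasks = []

  def add_scheduled_task(self, task):
    with self._lock:
      self.scheduled_tasks.append(task)

  def add_cancelled_task(self, task):
    with self._lock:
      self.cancelled_tasks.append(task)


class _FakeTaskScheduler(ts.TaskScheduler):
  """Scheduler that blocks on selected nodes until it is cancelled."""

  def __init__(self, block_nodes, collector, **kwargs):
    # Assumes the registry constructs schedulers with keyword arguments
    # accepted by the `TaskScheduler` base class (mlmd_handle, pipeline, task).
    super().__init__(**kwargs)
    self._block_nodes = block_nodes
    self._collector = collector
    self._cancel = threading.Event()

  def schedule(self):
    self._collector.add_scheduled_task(self.task)
    if self.task.node_uid.node_id in self._block_nodes:
      # Block until `cancel` is called; report CANCELLED in that case.
      self._cancel.wait()
      code = status_lib.Code.CANCELLED
    else:
      code = status_lib.Code.OK
    return ts.TaskSchedulerResult(
        status=status_lib.Status(
            code=code, message='_FakeTaskScheduler result'))

  def cancel(self):
    self._collector.add_cancelled_task(self.task)
    self._cancel.set()
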
def setUp(self):
  super().setUp()
  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())
  self._pipeline_root = pipeline_root

  # Makes sure multiple connections within a test always connect to the same
  # MLMD instance.
  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  self._metadata_path = metadata_path
  connection_config = metadata.sqlite_metadata_connection_config(
      metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)

  # Sets up the pipeline.
  pipeline = self._make_pipeline(self._pipeline_root, str(uuid.uuid4()))
  self._pipeline = pipeline

  # Extracts components.
  self._example_gen = test_utils.get_node(pipeline, 'my_example_gen')
  self._stats_gen = test_utils.get_node(pipeline, 'my_statistics_gen')
  self._schema_gen = test_utils.get_node(pipeline, 'my_schema_gen')
  self._transform = test_utils.get_node(pipeline, 'my_transform')
  self._example_validator = test_utils.get_node(pipeline,
                                                'my_example_validator')
  self._trainer = test_utils.get_node(pipeline, 'my_trainer')
  self._evaluator = test_utils.get_node(pipeline, 'my_evaluator')
  self._chore_a = test_utils.get_node(pipeline, 'chore_a')
  self._chore_b = test_utils.get_node(pipeline, 'chore_b')

  self._task_queue = tq.TaskQueue()

  self._mock_service_job_manager = mock.create_autospec(
      service_jobs.ServiceJobManager, instance=True)

  self._mock_service_job_manager.is_pure_service_node.side_effect = (
      lambda _, node_id: node_id == self._example_gen.node_info.id)
  self._mock_service_job_manager.is_mixed_service_node.side_effect = (
      lambda _, node_id: node_id == self._transform.node_info.id)

  def _default_ensure_node_services(unused_pipeline_state, node_id):
    self.assertIn(
        node_id,
        (self._example_gen.node_info.id, self._transform.node_info.id))
    return service_jobs.ServiceStatus.SUCCESS

  self._mock_service_job_manager.ensure_node_services.side_effect = (
      _default_ensure_node_services)

def test_exceptions_are_surfaced(self, mock_publish):

  def _publish(**kwargs):
    task = kwargs['task']
    assert task_lib.is_exec_node_task(task)
    if task.node_uid.node_id == 'Transform':
      raise ValueError('test error')
    return mock.DEFAULT

  mock_publish.side_effect = _publish

  collector = _Collector()

  # Register a fake task scheduler.
  ts.TaskSchedulerRegistry.register(
      self._type_url,
      functools.partial(
          _FakeTaskScheduler, block_nodes={}, collector=collector))

  task_queue = tq.TaskQueue()

  with self._task_manager(task_queue) as task_manager:
    transform_task = _test_exec_node_task(
        'Transform', 'test-pipeline', pipeline=self._pipeline)
    trainer_task = _test_exec_node_task(
        'Trainer', 'test-pipeline', pipeline=self._pipeline)
    task_queue.enqueue(transform_task)
    task_queue.enqueue(trainer_task)

  self.assertTrue(task_manager.done())
  exception = task_manager.exception()
  self.assertIsNotNone(exception)
  self.assertIsInstance(exception, tm.TasksProcessingError)
  self.assertLen(exception.errors, 1)
  self.assertEqual('test error', str(exception.errors[0]))

  self.assertCountEqual([transform_task, trainer_task],
                        collector.scheduled_tasks)
  mock_publish.assert_has_calls([
      mock.call(mlmd_handle=mock.ANY, task=transform_task, result=mock.ANY),
      mock.call(mlmd_handle=mock.ANY, task=trainer_task, result=mock.ANY),
  ], any_order=True)

def setUp(self):
  super(SyncPipelineTaskGeneratorTest, self).setUp()
  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())
  self._pipeline_root = pipeline_root

  # Makes sure multiple connections within a test always connect to the same
  # MLMD instance.
  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  self._metadata_path = metadata_path
  connection_config = metadata.sqlite_metadata_connection_config(
      metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)

  # Sets up the pipeline.
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(
          os.path.dirname(__file__), 'testdata', 'sync_pipeline.pbtxt'),
      pipeline)
  self._pipeline_run_id = str(uuid.uuid4())
  runtime_parameter_utils.substitute_runtime_parameter(
      pipeline, {
          'pipeline_root': pipeline_root,
          'pipeline_run_id': self._pipeline_run_id,
      })
  self._pipeline = pipeline

  # Extracts components.
  self._example_gen = _get_node(pipeline, 'my_example_gen')
  self._stats_gen = _get_node(pipeline, 'my_statistics_gen')
  self._schema_gen = _get_node(pipeline, 'my_schema_gen')
  self._transform = _get_node(pipeline, 'my_transform')
  self._example_validator = _get_node(pipeline, 'my_example_validator')
  self._trainer = _get_node(pipeline, 'my_trainer')

  self._task_queue = tq.TaskQueue()

  self._mock_service_job_manager = mock.create_autospec(
      service_jobs.ServiceJobManager, instance=True)

  self._mock_service_job_manager.is_pure_service_node.side_effect = (
      lambda _, node_id: node_id == self._example_gen.node_info.id)
  self._mock_service_job_manager.is_mixed_service_node.side_effect = (
      lambda _, node_id: node_id == self._transform.node_info.id)

def test_invalid_task_done_raises_errors(self):
  t1 = _test_task(node_id='trainer', pipeline_id='my_pipeline')
  t2 = _test_task(node_id='transform', pipeline_id='my_pipeline')
  tq = task_queue.TaskQueue()

  # Enqueue t1, but calling `task_done` raises an error since t1 is not
  # dequeued.
  self.assertTrue(tq.enqueue(t1))
  with self.assertRaisesRegex(RuntimeError, 'Must call `dequeue`'):
    tq.task_done(t1)

  # `task_done` succeeds after dequeueing.
  self.assertEqual(t1, tq.dequeue())
  tq.task_done(t1)

  # Error since t2 is not in the queue.
  with self.assertRaisesRegex(RuntimeError, 'Task not present'):
    tq.task_done(t2)

def setUp(self):
  super().setUp()

  # Set a constant version for artifact version tag.
  patcher = mock.patch('tfx.version.__version__')
  patcher.start()
  tfx_version.__version__ = '0.123.4.dev'
  self.addCleanup(patcher.stop)

  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())

  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  connection_config = metadata.sqlite_metadata_connection_config(
      metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)

  pipeline = self._make_pipeline(pipeline_root, str(uuid.uuid4()))
  self._pipeline = pipeline
  self._importer_node = self._pipeline.nodes[0].pipeline_node

  self._task_queue = tq.TaskQueue()
  [importer_task] = test_utils.run_generator_and_test(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=self._pipeline,
      task_queue=self._task_queue,
      use_task_queue=True,
      service_job_manager=None,
      num_initial_executions=0,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._importer_node],
      ignore_update_node_state_tasks=True)
  self._importer_task = importer_task

def test_handling_finalize_node_task(self, task_gen):
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1')
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline_ops.initiate_pipeline_start(m, pipeline)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    finalize_reason = status_lib.Status(
        code=status_lib.Code.ABORTED, message='foo bar')
    task_gen.return_value.generate.side_effect = [
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=pipeline_uid, node_id='Transform')),
            task_lib.FinalizeNodeTask(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=pipeline_uid, node_id='Trainer'),
                status=finalize_reason)
        ],
    ]

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())
    task_gen.return_value.generate.assert_called_once()
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)

    # Load pipeline state and verify node stop initiation.
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      self.assertEqual(
          finalize_reason,
          pipeline_state.node_stop_initiated_reason(
              task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id='Trainer')))

def test_handling_finalize_pipeline_task(self, task_gen):
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
    pipeline_ops.initiate_pipeline_start(m, pipeline)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    finalize_reason = status_lib.Status(
        code=status_lib.Code.ABORTED, message='foo bar')
    task_gen.return_value.generate.side_effect = [
        [
            task_lib.FinalizePipelineTask(
                pipeline_uid=pipeline_uid, status=finalize_reason)
        ],
    ]

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())
    task_gen.return_value.generate.assert_called_once()
    self.assertTrue(task_queue.is_empty())

    # Load pipeline state and verify stop initiation.
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      self.assertEqual(finalize_reason,
                       pipeline_state.stop_initiated_reason())

def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
                                  mock_async_task_gen, mock_sync_task_gen):
  with self._mlmd_connection as m:
    pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

    mock_service_job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)
    mock_service_job_manager.is_pure_service_node.side_effect = (
        lambda _, node_id: node_id == 'ExampleGen')
    mock_service_job_manager.is_mixed_service_node.side_effect = (
        lambda _, node_id: node_id == 'Transform')

    pipeline_ops.initiate_pipeline_start(m, pipeline)
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      pipeline_state.initiate_stop(
          status_lib.Status(code=status_lib.Code.CANCELLED))
    pipeline_execution = pipeline_state.execution

    task_queue = tq.TaskQueue()

    # For the stop-initiated pipeline, "Transform" execution task is in queue,
    # "Trainer" has an active execution in MLMD but no task in queue,
    # "Evaluator" has no active execution.
    task_queue.enqueue(
        test_utils.create_exec_node_task(
            task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                node_id='Transform')))
    transform_task = task_queue.dequeue()  # simulates task being processed
    mock_gen_task_from_active.side_effect = [
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                node_id='Trainer'),
            is_cancelled=True), None, None, None, None
    ]

    pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

    # There are no active pipelines so these shouldn't be called.
    mock_async_task_gen.assert_not_called()
    mock_sync_task_gen.assert_not_called()

    # stop_node_services should be called for ExampleGen which is a pure
    # service node.
    mock_service_job_manager.stop_node_services.assert_called_once_with(
        mock.ANY, 'ExampleGen')
    mock_service_job_manager.reset_mock()

    task_queue.task_done(transform_task)  # Pop out transform task.

    # CancelNodeTask for the "Transform" ExecNodeTask should be next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_cancel_node_task(task))
    self.assertEqual('Transform', task.node_uid.node_id)

    # ExecNodeTask (with is_cancelled=True) for "Trainer" is next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual('Trainer', task.node_uid.node_id)
    self.assertTrue(task.is_cancelled)

    self.assertTrue(task_queue.is_empty())

    mock_gen_task_from_active.assert_has_calls([
        mock.call(
            m,
            pipeline_state.pipeline,
            pipeline.nodes[2].pipeline_node,
            mock.ANY,
            is_cancelled=True),
        mock.call(
            m,
            pipeline_state.pipeline,
            pipeline.nodes[3].pipeline_node,
            mock.ANY,
            is_cancelled=True)
    ])
    self.assertEqual(2, mock_gen_task_from_active.call_count)

    # Pipeline execution should continue to be active since active node
    # executions were found in the last call to `orchestrate`.
    [execution] = m.store.get_executions_by_id([pipeline_execution.id])
    self.assertTrue(execution_lib.is_execution_active(execution))

    # Call `orchestrate` again; this time there are no more active node
    # executions so the pipeline should be marked as cancelled.
    pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
    self.assertTrue(task_queue.is_empty())
    [execution] = m.store.get_executions_by_id([pipeline_execution.id])
    self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                     execution.last_known_state)

    # stop_node_services should be called on both ExampleGen and Transform
    # which are service nodes.
    mock_service_job_manager.stop_node_services.assert_has_calls(
        [mock.call(mock.ANY, 'ExampleGen'),
         mock.call(mock.ANY, 'Transform')],
        any_order=True)

def test_orchestrate_active_pipelines(self, mock_async_task_gen,
                                      mock_sync_task_gen):
  with self._mlmd_connection as m:
    # Sync and async active pipelines.
    async_pipelines = [
        _test_pipeline('pipeline1'),
        _test_pipeline('pipeline2'),
    ]
    sync_pipelines = [
        _test_pipeline('pipeline3', pipeline_pb2.Pipeline.SYNC),
        _test_pipeline('pipeline4', pipeline_pb2.Pipeline.SYNC),
    ]

    for pipeline in async_pipelines + sync_pipelines:
      pipeline_ops.initiate_pipeline_start(m, pipeline)

    # Active executions for active async pipelines.
    mock_async_task_gen.return_value.generate.side_effect = [
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        async_pipelines[0]),
                    node_id='Transform'))
        ],
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        async_pipelines[1]),
                    node_id='Trainer'))
        ],
    ]

    # Active executions for active sync pipelines.
    mock_sync_task_gen.return_value.generate.side_effect = [
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        sync_pipelines[0]),
                    node_id='Trainer'))
        ],
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        sync_pipelines[1]),
                    node_id='Validator'))
        ],
    ]

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())

    self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
    self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count)

    # Verify that tasks are enqueued in the expected order.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline2', 'Trainer'), task.node_uid)
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline3', 'Trainer'), task.node_uid)
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline4', 'Validator'), task.node_uid)
    self.assertTrue(task_queue.is_empty())

def setUp(self):
  super(TaskManagerE2ETest, self).setUp()
  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())

  # Makes sure multiple connections within a test always connect to the same
  # MLMD instance.
  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  self._metadata_path = metadata_path
  connection_config = metadata.sqlite_metadata_connection_config(
      metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)

  # Sets up the pipeline.
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(
          os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
      pipeline)

  # Extracts components.
  self._example_gen = pipeline.nodes[0].pipeline_node
  self._transform = pipeline.nodes[1].pipeline_node
  self._trainer = pipeline.nodes[2].pipeline_node

  # Pack deployment config for testing.
  deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
  executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
      class_path='fake.ClassPath')
  deployment_config.executor_specs[self._trainer.node_info.id].Pack(
      executor_spec)
  deployment_config.executor_specs[self._transform.node_info.id].Pack(
      executor_spec)
  self._type_url = deployment_config.executor_specs[
      self._trainer.node_info.id].type_url
  pipeline.deployment_config.Pack(deployment_config)

  self._pipeline = pipeline
  self._pipeline_info = pipeline.pipeline_info
  self._pipeline_runtime_spec = pipeline.runtime_spec
  self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
      pipeline_root)

  ts.TaskSchedulerRegistry.clear()

  self._task_queue = tq.TaskQueue()

  # Run fake example-gen to prepare downstream component triggers.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)

  # Task generator should produce a task to run transform.
  with self._mlmd_connection as m:
    pipeline_state = pstate.PipelineState(m, self._pipeline, 0)
    tasks = asptg.AsyncPipelineTaskGenerator(
        m, pipeline_state, self._task_queue.contains_task_id,
        service_jobs.DummyServiceJobManager()).generate()
  self.assertLen(tasks, 1)
  self._task = tasks[0]
  self.assertEqual('my_transform', self._task.node_uid.node_id)
  self._task_queue.enqueue(self._task)

  # There should be 1 active execution in MLMD.
  with self._mlmd_connection as m:
    executions = m.store.get_executions()
  active_executions = [
      e for e in executions
      if e.last_known_state == metadata_store_pb2.Execution.RUNNING
  ]
  self.assertLen(active_executions, 1)

  # Active execution id.
  self._execution_id = active_executions[0].id

def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                        mock_async_task_gen,
                                        mock_sync_task_gen):
  with self._mlmd_connection as m:
    pipeline1 = _test_pipeline('pipeline1')
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Evaluator'
    pipeline_ops.initiate_pipeline_start(m, pipeline1)
    pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
        m, task_lib.PipelineUid.from_pipeline(pipeline1))

    task_queue = tq.TaskQueue()

    # For the stop-initiated pipeline, "Transform" execution task is in queue,
    # "Trainer" has an active execution in MLMD but no task in queue,
    # "Evaluator" has no active execution.
    task_queue.enqueue(
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid(
                    pipeline_id='pipeline1', pipeline_run_id=None),
                node_id='Transform')))
    transform_task = task_queue.dequeue()  # simulates task being processed
    mock_gen_task_from_active.side_effect = [
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid(
                    pipeline_id='pipeline1', pipeline_run_id=None),
                node_id='Trainer'),
            is_cancelled=True), None, None, None, None
    ]

    pipeline_ops.generate_tasks(m, task_queue)

    # There are no active pipelines so these shouldn't be called.
    mock_async_task_gen.assert_not_called()
    mock_sync_task_gen.assert_not_called()

    # Simulate finishing the "Transform" ExecNodeTask.
    task_queue.task_done(transform_task)

    # CancelNodeTask for the "Transform" ExecNodeTask should be next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_cancel_node_task(task))
    self.assertEqual('Transform', task.node_uid.node_id)

    # ExecNodeTask for "Trainer" is next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual('Trainer', task.node_uid.node_id)

    self.assertTrue(task_queue.is_empty())

    mock_gen_task_from_active.assert_has_calls([
        mock.call(
            m,
            pipeline1,
            pipeline1.nodes[1].pipeline_node,
            mock.ANY,
            is_cancelled=True),
        mock.call(
            m,
            pipeline1,
            pipeline1.nodes[2].pipeline_node,
            mock.ANY,
            is_cancelled=True)
    ])
    self.assertEqual(2, mock_gen_task_from_active.call_count)

    # Pipeline execution should continue to be active since active node
    # executions were found in the last call to `generate_tasks`.
    [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
    self.assertTrue(execution_lib.is_execution_active(execution))

    # Call `generate_tasks` again; this time there are no more active node
    # executions so the pipeline should be marked as cancelled.
    pipeline_ops.generate_tasks(m, task_queue)
    self.assertTrue(task_queue.is_empty())
    [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
    self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                     execution.last_known_state)

def test_resolver_task_scheduler(self):
  with self._mlmd_connection as m:
    # Publishes two models which will be consumed by downstream resolver.
    output_model_1 = types.Artifact(
        self._trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_1.uri = 'my_model_uri_1'
    output_model_2 = types.Artifact(
        self._trainer.outputs.outputs['model'].artifact_spec.type)
    output_model_2.uri = 'my_model_uri_2'
    contexts = context_lib.prepare_contexts(m, self._trainer.contexts)
    execution = execution_publish_utils.register_execution(
        m, self._trainer.node_info.type, contexts)
    execution_publish_utils.publish_succeeded_execution(
        m, execution.id, contexts, {
            'model': [output_model_1, output_model_2],
        })

  task_queue = tq.TaskQueue()

  # Verify that resolver task is generated.
  [resolver_task] = test_utils.run_generator_and_test(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=self._pipeline,
      task_queue=task_queue,
      use_task_queue=False,
      service_job_manager=None,
      num_initial_executions=1,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._resolver_node],
      ignore_update_node_state_tasks=True)

  with self._mlmd_connection as m:
    # Run resolver task scheduler and publish results.
    ts_result = resolver_task_scheduler.ResolverTaskScheduler(
        mlmd_handle=m, pipeline=self._pipeline, task=resolver_task).schedule()
    self.assertEqual(status_lib.Code.OK, ts_result.status.code)
    self.assertIsInstance(ts_result.output, task_scheduler.ResolverNodeOutput)
    self.assertCountEqual(['resolved_model'],
                          ts_result.output.resolved_input_artifacts.keys())
    models = ts_result.output.resolved_input_artifacts['resolved_model']
    self.assertLen(models, 1)
    self.assertEqual('my_model_uri_2', models[0].mlmd_artifact.uri)
    tm._publish_execution_results(m, resolver_task, ts_result)

  # Verify resolver node output is input to the downstream consumer node.
  [consumer_task] = test_utils.run_generator_and_test(
      test_case=self,
      mlmd_connection=self._mlmd_connection,
      generator_class=sptg.SyncPipelineTaskGenerator,
      pipeline=self._pipeline,
      task_queue=task_queue,
      use_task_queue=False,
      service_job_manager=None,
      num_initial_executions=2,
      num_tasks_generated=1,
      num_new_executions=1,
      num_active_executions=1,
      expected_exec_nodes=[self._consumer_node],
      ignore_update_node_state_tasks=True)
  self.assertCountEqual(['resolved_model'],
                        consumer_task.input_artifacts.keys())
  input_models = consumer_task.input_artifacts['resolved_model']
  self.assertLen(input_models, 1)
  self.assertEqual('my_model_uri_2', input_models[0].mlmd_artifact.uri)

def test_task_handling(self, mock_publish):
  collector = _Collector()

  # Register a fake task scheduler.
  ts.TaskSchedulerRegistry.register(
      self._type_url,
      functools.partial(
          _FakeTaskScheduler,
          block_nodes={'Trainer', 'Transform', 'Pusher'},
          collector=collector))

  task_queue = tq.TaskQueue()

  # Enqueue some tasks.
  trainer_exec_task = _test_exec_node_task(
      'Trainer', 'test-pipeline', pipeline=self._pipeline)
  task_queue.enqueue(trainer_exec_task)
  task_queue.enqueue(_test_cancel_node_task('Trainer', 'test-pipeline'))

  with self._task_manager(task_queue) as task_manager:
    # Enqueue more tasks after task manager starts.
    transform_exec_task = _test_exec_node_task(
        'Transform', 'test-pipeline', pipeline=self._pipeline)
    task_queue.enqueue(transform_exec_task)
    evaluator_exec_task = _test_exec_node_task(
        'Evaluator', 'test-pipeline', pipeline=self._pipeline)
    task_queue.enqueue(evaluator_exec_task)
    task_queue.enqueue(_test_cancel_node_task('Transform', 'test-pipeline'))
    pusher_exec_task = _test_exec_node_task(
        'Pusher', 'test-pipeline', pipeline=self._pipeline)
    task_queue.enqueue(pusher_exec_task)
    task_queue.enqueue(
        _test_cancel_node_task('Pusher', 'test-pipeline', pause=True))

  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # Ensure that all exec and cancellation tasks were processed correctly.
  self.assertCountEqual([
      trainer_exec_task,
      transform_exec_task,
      evaluator_exec_task,
      pusher_exec_task,
  ], collector.scheduled_tasks)
  self.assertCountEqual([
      trainer_exec_task,
      transform_exec_task,
      pusher_exec_task,
  ], collector.cancelled_tasks)

  result_ok = ts.TaskSchedulerResult(
      status=status_lib.Status(
          code=status_lib.Code.OK, message='_FakeTaskScheduler result'))
  result_cancelled = ts.TaskSchedulerResult(
      status=status_lib.Status(
          code=status_lib.Code.CANCELLED, message='_FakeTaskScheduler result'))
  mock_publish.assert_has_calls([
      mock.call(
          mlmd_handle=mock.ANY,
          task=trainer_exec_task,
          result=result_cancelled),
      mock.call(
          mlmd_handle=mock.ANY,
          task=transform_exec_task,
          result=result_cancelled),
      mock.call(
          mlmd_handle=mock.ANY, task=evaluator_exec_task, result=result_ok),
  ], any_order=True)
  # It is expected that publish is not called for Pusher because it was
  # cancelled with pause=True, so there must be only 3 calls.
  self.assertLen(mock_publish.mock_calls, 3)

def test_active_pipelines_with_stop_initiated_nodes(self,
                                                    mock_gen_task_from_active,
                                                    mock_async_task_gen):
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline')
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

    transform_node_uid = task_lib.NodeUid.from_pipeline_node(
        pipeline, pipeline.nodes[0].pipeline_node)
    transform_task = test_utils.create_exec_node_task(
        node_uid=transform_node_uid)

    trainer_node_uid = task_lib.NodeUid.from_pipeline_node(
        pipeline, pipeline.nodes[1].pipeline_node)
    trainer_task = test_utils.create_exec_node_task(node_uid=trainer_node_uid)

    evaluator_node_uid = task_lib.NodeUid.from_pipeline_node(
        pipeline, pipeline.nodes[2].pipeline_node)
    evaluator_task = test_utils.create_exec_node_task(
        node_uid=evaluator_node_uid)
    cancelled_evaluator_task = test_utils.create_exec_node_task(
        node_uid=evaluator_node_uid, is_cancelled=True)

    pipeline_ops.initiate_pipeline_start(m, pipeline)
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      # Stop trainer and evaluator.
      pipeline_state.initiate_node_stop(trainer_node_uid)
      pipeline_state.initiate_node_stop(evaluator_node_uid)

    task_queue = tq.TaskQueue()

    # Simulate a new transform execution being triggered.
    mock_async_task_gen.return_value.generate.return_value = [transform_task]
    # Simulate ExecNodeTask for trainer already present in the task queue.
    task_queue.enqueue(trainer_task)
    # Simulate Evaluator having an active execution in MLMD.
    mock_gen_task_from_active.side_effect = [evaluator_task]

    pipeline_ops.orchestrate(m, task_queue)
    self.assertEqual(1, mock_async_task_gen.return_value.generate.call_count)

    # Verify that tasks are enqueued in the expected order:

    # Pre-existing trainer task.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertEqual(trainer_task, task)

    # CancelNodeTask for trainer.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_cancel_node_task(task))
    self.assertEqual(trainer_node_uid, task.node_uid)

    # ExecNodeTask with is_cancelled=True for evaluator.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(cancelled_evaluator_task, task)

    # ExecNodeTask for newly triggered transform node.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertEqual(transform_task, task)

    # No more tasks.
    self.assertTrue(task_queue.is_empty())

def test_active_pipelines_with_stop_initiated_nodes(self,
                                                    mock_gen_task_from_active,
                                                    mock_async_task_gen):
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline')
    pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

    mock_service_job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)
    mock_service_job_manager.is_pure_service_node.side_effect = (
        lambda _, node_id: node_id == 'ExampleGen')

    example_gen_node_uid = task_lib.NodeUid.from_pipeline_node(
        pipeline, pipeline.nodes[0].pipeline_node)

    transform_node_uid = task_lib.NodeUid.from_pipeline_node(
        pipeline, pipeline.nodes[1].pipeline_node)
    transform_task = test_utils.create_exec_node_task(
        node_uid=transform_node_uid)

    trainer_node_uid = task_lib.NodeUid.from_pipeline_node(
        pipeline, pipeline.nodes[2].pipeline_node)
    trainer_task = test_utils.create_exec_node_task(node_uid=trainer_node_uid)

    evaluator_node_uid = task_lib.NodeUid.from_pipeline_node(
        pipeline, pipeline.nodes[3].pipeline_node)
    evaluator_task = test_utils.create_exec_node_task(
        node_uid=evaluator_node_uid)
    cancelled_evaluator_task = test_utils.create_exec_node_task(
        node_uid=evaluator_node_uid, is_cancelled=True)

    pipeline_ops.initiate_pipeline_start(m, pipeline)
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      # Stop example-gen, trainer and evaluator.
      pipeline_state.initiate_node_stop(
          example_gen_node_uid,
          status_lib.Status(code=status_lib.Code.CANCELLED))
      pipeline_state.initiate_node_stop(
          trainer_node_uid,
          status_lib.Status(code=status_lib.Code.CANCELLED))
      pipeline_state.initiate_node_stop(
          evaluator_node_uid,
          status_lib.Status(code=status_lib.Code.ABORTED))

    task_queue = tq.TaskQueue()

    # Simulate a new transform execution being triggered.
    mock_async_task_gen.return_value.generate.return_value = [transform_task]
    # Simulate ExecNodeTask for trainer already present in the task queue.
    task_queue.enqueue(trainer_task)
    # Simulate Evaluator having an active execution in MLMD.
    mock_gen_task_from_active.side_effect = [evaluator_task]

    pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
    self.assertEqual(1, mock_async_task_gen.return_value.generate.call_count)

    # stop_node_services should be called on example-gen which is a pure
    # service node.
    mock_service_job_manager.stop_node_services.assert_called_once_with(
        mock.ANY, 'ExampleGen')

    # Verify that tasks are enqueued in the expected order:

    # Pre-existing trainer task.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertEqual(trainer_task, task)

    # CancelNodeTask for trainer.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_cancel_node_task(task))
    self.assertEqual(trainer_node_uid, task.node_uid)

    # ExecNodeTask with is_cancelled=True for evaluator.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(cancelled_evaluator_task, task)

    # ExecNodeTask for newly triggered transform node.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertEqual(transform_task, task)

    # No more tasks.
    self.assertTrue(task_queue.is_empty())

def test_generate_tasks_async_active_pipelines(self, mock_async_task_gen,
                                               mock_sync_task_gen):
  with self._mlmd_connection as m:
    # One active pipeline.
    pipeline1 = _test_pipeline('pipeline1')
    pipeline_ops.initiate_pipeline_start(m, pipeline1)

    # Another active pipeline (with previously completed execution).
    pipeline2 = _test_pipeline('pipeline2')
    execution2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)
    execution2.last_known_state = metadata_store_pb2.Execution.COMPLETE
    m.store.put_executions([execution2])
    execution2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)

    # Inactive pipelines should be ignored.
    pipeline3 = _test_pipeline('pipeline3')
    execution3 = pipeline_ops.initiate_pipeline_start(m, pipeline3)
    execution3.last_known_state = metadata_store_pb2.Execution.COMPLETE
    m.store.put_executions([execution3])

    # For active pipelines pipeline1 and pipeline2, there are a couple of
    # active executions.
    def _exec_node_tasks():
      for pipeline_id in ('pipeline1', 'pipeline2'):
        yield [
            test_utils.create_exec_node_task(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(
                        pipeline_id=pipeline_id, pipeline_run_id=None),
                    node_id='Transform')),
            test_utils.create_exec_node_task(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(
                        pipeline_id=pipeline_id, pipeline_run_id=None),
                    node_id='Trainer'))
        ]

    mock_async_task_gen.return_value.generate.side_effect = _exec_node_tasks()

    task_queue = tq.TaskQueue()
    pipeline_ops.generate_tasks(m, task_queue)

    self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
    mock_sync_task_gen.assert_not_called()

    # Verify that tasks are enqueued in the expected order.
    for node_id in ('Transform', 'Trainer'):
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(node_id, task.node_uid.node_id)
      self.assertEqual('pipeline1', task.node_uid.pipeline_uid.pipeline_id)
    for node_id in ('Transform', 'Trainer'):
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(node_id, task.node_uid.node_id)
      self.assertEqual('pipeline2', task.node_uid.pipeline_uid.pipeline_id)
    self.assertTrue(task_queue.is_empty())

def setUp(self):
  super().setUp()
  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())

  # Makes sure multiple connections within a test always connect to the same
  # MLMD instance.
  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  self._metadata_path = metadata_path
  connection_config = metadata.sqlite_metadata_connection_config(
      metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)

  # Sets up the pipeline.
  pipeline = test_async_pipeline.create_pipeline()

  # Extracts components.
  self._example_gen = pipeline.nodes[0].pipeline_node
  self._transform = pipeline.nodes[1].pipeline_node
  self._trainer = pipeline.nodes[2].pipeline_node

  # Pack deployment config for testing.
  deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
  executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
      class_path='fake.ClassPath')
  deployment_config.executor_specs[self._trainer.node_info.id].Pack(
      executor_spec)
  deployment_config.executor_specs[self._transform.node_info.id].Pack(
      executor_spec)
  self._type_url = deployment_config.executor_specs[
      self._trainer.node_info.id].type_url
  pipeline.deployment_config.Pack(deployment_config)

  self._pipeline = pipeline
  self._pipeline_info = pipeline.pipeline_info
  self._pipeline_runtime_spec = pipeline.runtime_spec
  self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
      pipeline_root)

  ts.TaskSchedulerRegistry.clear()

  self._task_queue = tq.TaskQueue()

  # Run fake example-gen to prepare downstream component triggers.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)

  # Task generator should produce two tasks for transform. The first one is
  # UpdateNodeStateTask and the second one is ExecNodeTask.
  with self._mlmd_connection as m:
    pipeline_state = pstate.PipelineState.new(m, self._pipeline)
    tasks = asptg.AsyncPipelineTaskGenerator(
        m, self._task_queue.contains_task_id,
        service_jobs.DummyServiceJobManager()).generate(pipeline_state)
  self.assertLen(tasks, 2)
  self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
  self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
  self.assertEqual('my_transform', tasks[0].node_uid.node_id)
  self.assertTrue(task_lib.is_exec_node_task(tasks[1]))
  self.assertEqual('my_transform', tasks[1].node_uid.node_id)
  self.assertTrue(os.path.exists(tasks[1].stateful_working_dir))
  self.assertTrue(os.path.exists(tasks[1].tmp_dir))
  self._task = tasks[1]
  self._output_artifact_uri = self._task.output_artifacts[
      'transform_graph'][0].uri
  self.assertTrue(os.path.exists(self._output_artifact_uri))
  self._task_queue.enqueue(self._task)

  # There should be 1 active execution in MLMD.
  with self._mlmd_connection as m:
    executions = m.store.get_executions()
  active_executions = [
      e for e in executions
      if e.last_known_state == metadata_store_pb2.Execution.RUNNING
  ]
  self.assertLen(active_executions, 1)

  # Active execution id.
  self._execution_id = active_executions[0].id