def run_generator_and_test(test_case, mlmd_connection, generator_class, pipeline, task_queue, use_task_queue, service_job_manager, num_initial_executions, num_tasks_generated, num_new_executions, num_active_executions, expected_exec_nodes=None, ignore_node_ids=None):
  """Runs generator.generate() and tests the effects."""
  if service_job_manager is None:
    service_job_manager = service_jobs.DummyServiceJobManager()
  with mlmd_connection as mlmd_handle:
    # Sanity-check the MLMD state before running the generator.
    test_case.assertLen(
        mlmd_handle.store.get_executions(), num_initial_executions,
        f'Expected {num_initial_executions} execution(s) in MLMD.')

    pipeline_state = pstate.PipelineState.new(mlmd_handle, pipeline)
    generator_kwargs = {
        'mlmd_handle': mlmd_handle,
        'pipeline_state': pipeline_state,
        'is_task_id_tracked_fn': task_queue.contains_task_id,
        'service_job_manager': service_job_manager,
    }
    # Only the async task generator accepts `ignore_node_ids`.
    if generator_class == asptg.AsyncPipelineTaskGenerator:
      generator_kwargs['ignore_node_ids'] = ignore_node_ids
    tasks = generator_class(**generator_kwargs).generate()
    test_case.assertLen(
        tasks, num_tasks_generated,
        f'Expected {num_tasks_generated} task(s) to be generated.')

    # Generation may register new executions; verify totals and active count.
    executions = mlmd_handle.store.get_executions()
    num_total_executions = num_initial_executions + num_new_executions
    test_case.assertLen(
        executions, num_total_executions,
        f'Expected {num_total_executions} execution(s) in MLMD.')
    active_executions = [
        e for e in executions if execution_lib.is_execution_active(e)
    ]
    test_case.assertLen(
        active_executions, num_active_executions,
        f'Expected {num_active_executions} active execution(s) in MLMD.')

    # Pair each generated task with its expected node and active execution.
    if expected_exec_nodes:
      for index, task in enumerate(tasks):
        _verify_exec_node_task(test_case, pipeline, expected_exec_nodes[index],
                               active_executions[index].id, task)

    # Optionally enqueue exec-node tasks so follow-up calls see them tracked.
    if use_task_queue:
      for task in tasks:
        if task_lib.is_exec_node_task(task):
          task_queue.enqueue(task)
    return tasks
def run_generator_and_test(test_case, mlmd_connection, generator_class, pipeline, task_queue, use_task_queue, service_job_manager, num_initial_executions, num_tasks_generated, num_new_executions, num_active_executions, expected_exec_nodes=None, ignore_update_node_state_tasks=False, fail_fast=None):
  """Runs generator.generate() and tests the effects."""
  if service_job_manager is None:
    service_job_manager = service_jobs.DummyServiceJobManager()

  # Sanity-check the MLMD state before running the generator.
  with mlmd_connection as mlmd_handle:
    test_case.assertLen(
        get_non_orchestrator_executions(mlmd_handle), num_initial_executions,
        f'Expected {num_initial_executions} execution(s) in MLMD.')

  # The actual generation is delegated to the `run_generator` helper.
  tasks = run_generator(
      mlmd_connection,
      generator_class,
      pipeline,
      task_queue,
      use_task_queue,
      service_job_manager,
      ignore_update_node_state_tasks=ignore_update_node_state_tasks,
      fail_fast=fail_fast)

  with mlmd_connection as mlmd_handle:
    test_case.assertLen(
        tasks, num_tasks_generated,
        f'Expected {num_tasks_generated} task(s) to be generated.')

    # Generation may register new executions; verify totals and active count.
    executions = get_non_orchestrator_executions(mlmd_handle)
    num_total_executions = num_initial_executions + num_new_executions
    test_case.assertLen(
        executions, num_total_executions,
        f'Expected {num_total_executions} execution(s) in MLMD.')
    active_executions = [
        e for e in executions if execution_lib.is_execution_active(e)
    ]
    test_case.assertLen(
        active_executions, num_active_executions,
        f'Expected {num_active_executions} active execution(s) in MLMD.')

    # Only exec-node tasks are paired with expected nodes; other task kinds
    # (e.g. update-node-state tasks) are skipped by the filter.
    if expected_exec_nodes:
      exec_node_tasks = [t for t in tasks if task_lib.is_exec_node_task(t)]
      for index, task in enumerate(exec_node_tasks):
        _verify_exec_node_task(test_case, pipeline, expected_exec_nodes[index],
                               active_executions[index].id, task)
  return tasks
def test_no_tasks_generated_when_no_inputs(self, min_count):
  """Tests no tasks are generated when there are no inputs, regardless of min_count."""
  # Force every declared input of every node to require `min_count` artifacts.
  for node in self._pipeline.nodes:
    for input_spec in node.pipeline_node.inputs.inputs.values():
      input_spec.min_count = min_count

  with self._mlmd_connection as mlmd_handle:
    pipeline_state = test_utils.get_or_create_pipeline_state(
        mlmd_handle, self._pipeline)
    generator = asptg.AsyncPipelineTaskGenerator(
        mlmd_handle, lambda _: False, service_jobs.DummyServiceJobManager())
    tasks = generator.generate(pipeline_state)
    self.assertEmpty(tasks, 'Expected no task generation when no inputs.')
    # Without generated tasks, no executions should have been registered.
    self.assertEmpty(
        test_utils.get_non_orchestrator_executions(mlmd_handle),
        'There must not be any registered executions since no tasks were '
        'generated.')
def test_no_tasks_generated_when_no_inputs(self, min_count):
  """Tests no tasks are generated when there are no inputs, regardless of min_count."""
  # Force every declared input of every node to require `min_count` artifacts.
  for node in self._pipeline.nodes:
    for input_spec in node.pipeline_node.inputs.inputs.values():
      input_spec.min_count = min_count

  with self._mlmd_connection as mlmd_handle:
    pipeline_state = pstate.PipelineState(mlmd_handle, self._pipeline, 0)
    generator = asptg.AsyncPipelineTaskGenerator(
        mlmd_handle,
        pipeline_state,
        lambda _: False,
        service_jobs.DummyServiceJobManager(),
        ignore_node_ids={self._example_gen.node_info.id})
    tasks = generator.generate()
    self.assertEmpty(tasks, 'Expected no task generation when no inputs.')
    # Without generated tasks, no executions should have been registered.
    self.assertEmpty(
        mlmd_handle.store.get_executions(),
        'There must not be any registered executions since no tasks were '
        'generated.')
def test_handling_finalize_node_task(self, task_gen):
  """Verifies orchestrate() handles a FinalizeNodeTask from the generator.

  `task_gen` is a mocked task-generator class (presumably patched in by a
  decorator on this test method -- confirm against the enclosing class).
  The generator yields one ExecNodeTask (Transform) and one FinalizeNodeTask
  (Trainer); only the former should land in the task queue, while the latter
  should result in node stop initiation recorded in pipeline state.
  """
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1')
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline_ops.initiate_pipeline_start(m, pipeline)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    finalize_reason = status_lib.Status(
        code=status_lib.Code.ABORTED, message='foo bar')
    # One orchestration round: ExecNodeTask for Transform plus a
    # FinalizeNodeTask for Trainer carrying the ABORTED status.
    task_gen.return_value.generate.side_effect = [
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=pipeline_uid, node_id='Transform')),
            task_lib.FinalizeNodeTask(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=pipeline_uid, node_id='Trainer'),
                status=finalize_reason)
        ],
    ]

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())
    task_gen.return_value.generate.assert_called_once()

    # Only the ExecNodeTask is enqueued for the task manager; the
    # FinalizeNodeTask is consumed by the orchestrator itself.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)

    # Load pipeline state and verify node stop initiation.
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      self.assertEqual(
          finalize_reason,
          pipeline_state.node_stop_initiated_reason(
              task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id='Trainer')))
def test_handling_finalize_pipeline_task(self, task_gen):
  """Verifies orchestrate() handles a FinalizePipelineTask from the generator.

  `task_gen` is a mocked task-generator class (presumably patched in by a
  decorator on this test method -- confirm against the enclosing class).
  """
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
    pipeline_ops.initiate_pipeline_start(m, pipeline)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    abort_status = status_lib.Status(
        code=status_lib.Code.ABORTED, message='foo bar')
    # Single orchestration round yielding only a FinalizePipelineTask.
    task_gen.return_value.generate.side_effect = [
        [
            task_lib.FinalizePipelineTask(
                pipeline_uid=pipeline_uid, status=abort_status)
        ],
    ]

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())
    task_gen.return_value.generate.assert_called_once()

    # The FinalizePipelineTask is consumed by the orchestrator; nothing is
    # enqueued for the task manager.
    self.assertTrue(task_queue.is_empty())

    # The finalize status must be recorded as the stop-initiation reason.
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      self.assertEqual(abort_status, pipeline_state.stop_initiated_reason())
def test_orchestrate_active_pipelines(self, mock_async_task_gen,
                                      mock_sync_task_gen):
  """Verifies one orchestrate() round over multiple active pipelines.

  Both generator classes are mocked (presumably patched in by decorators on
  this test method -- confirm against the enclosing class). Two async and
  two sync pipelines are started; each mocked generator returns one
  ExecNodeTask per pipeline, and the test asserts the tasks are enqueued in
  async-then-sync, per-pipeline order.
  """
  with self._mlmd_connection as m:
    # Sync and async active pipelines.
    async_pipelines = [
        _test_pipeline('pipeline1'),
        _test_pipeline('pipeline2'),
    ]
    sync_pipelines = [
        _test_pipeline('pipeline3', pipeline_pb2.Pipeline.SYNC),
        _test_pipeline('pipeline4', pipeline_pb2.Pipeline.SYNC),
    ]
    for pipeline in async_pipelines + sync_pipelines:
      pipeline_ops.initiate_pipeline_start(m, pipeline)

    # Active executions for active async pipelines.
    # side_effect list order matters: one generate() call per pipeline.
    mock_async_task_gen.return_value.generate.side_effect = [
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        async_pipelines[0]),
                    node_id='Transform'))
        ],
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        async_pipelines[1]),
                    node_id='Trainer'))
        ],
    ]

    # Active executions for active sync pipelines.
    mock_sync_task_gen.return_value.generate.side_effect = [
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        sync_pipelines[0]),
                    node_id='Trainer'))
        ],
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        sync_pipelines[1]),
                    node_id='Validator'))
        ],
    ]

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())

    # Each mocked generator must have been invoked once per active pipeline
    # of its kind (two async, two sync).
    self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
    self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count)

    # Verify that tasks are enqueued in the expected order:
    # async pipelines first (pipeline1, pipeline2), then sync
    # (pipeline3, pipeline4).
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline2', 'Trainer'), task.node_uid)
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline3', 'Trainer'), task.node_uid)
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline4', 'Validator'), task.node_uid)
    # No further tasks should remain after the four expected ones.
    self.assertTrue(task_queue.is_empty())
def setUp(self):
  """Builds an end-to-end fixture: MLMD store, async pipeline, and one active transform task."""
  super(TaskManagerE2ETest, self).setUp()
  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())

  # Makes sure multiple connections within a test always connect to the same
  # MLMD instance.
  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  self._metadata_path = metadata_path
  connection_config = metadata.sqlite_metadata_connection_config(
      metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)

  # Sets up the pipeline from the checked-in async pipeline proto.
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(
          os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
      pipeline)

  # Extracts components; node order follows the pbtxt definition.
  self._example_gen = pipeline.nodes[0].pipeline_node
  self._transform = pipeline.nodes[1].pipeline_node
  self._trainer = pipeline.nodes[2].pipeline_node

  # Pack deployment config for testing. 'fake.ClassPath' never resolves;
  # tests presumably intercept execution before the executor is loaded.
  deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
  executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
      class_path='fake.ClassPath')
  deployment_config.executor_specs[self._trainer.node_info.id].Pack(
      executor_spec)
  deployment_config.executor_specs[self._transform.node_info.id].Pack(
      executor_spec)
  self._type_url = deployment_config.executor_specs[
      self._trainer.node_info.id].type_url
  pipeline.deployment_config.Pack(deployment_config)
  self._pipeline = pipeline
  self._pipeline_info = pipeline.pipeline_info
  self._pipeline_runtime_spec = pipeline.runtime_spec
  self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
      pipeline_root)

  # Start from a clean scheduler registry and an empty task queue.
  ts.TaskSchedulerRegistry.clear()
  self._task_queue = tq.TaskQueue()

  # Run fake example-gen to prepare downstreams component triggers.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen,
                                  1, 1)

  # Task generator should produce a task to run transform.
  with self._mlmd_connection as m:
    pipeline_state = pstate.PipelineState(m, self._pipeline, 0)
    tasks = asptg.AsyncPipelineTaskGenerator(
        m, pipeline_state, self._task_queue.contains_task_id,
        service_jobs.DummyServiceJobManager()).generate()
  self.assertLen(tasks, 1)
  task = tasks[0]
  self.assertEqual('my_transform', task.node_uid.node_id)

  # The first task was never enqueued, so it is not tracked by the queue and
  # a second generation round produces the same transform task again; this
  # one is kept on self and enqueued.
  with self._mlmd_connection as m:
    pipeline_state = pstate.PipelineState(m, self._pipeline, 0)
    tasks = asptg.AsyncPipelineTaskGenerator(
        m, pipeline_state, self._task_queue.contains_task_id,
        service_jobs.DummyServiceJobManager()).generate()
  self.assertLen(tasks, 1)
  self._task = tasks[0]
  self.assertEqual('my_transform', self._task.node_uid.node_id)
  self._task_queue.enqueue(self._task)

  # There should be 1 active execution in MLMD.
  with self._mlmd_connection as m:
    executions = m.store.get_executions()
  active_executions = [
      e for e in executions
      if e.last_known_state == metadata_store_pb2.Execution.RUNNING
  ]
  self.assertLen(active_executions, 1)

  # Active execution id.
  self._execution_id = active_executions[0].id
def setUp(self):
  """Builds an end-to-end fixture: MLMD store, async pipeline, and one active transform task."""
  super().setUp()
  pipeline_root = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self.id())

  # Makes sure multiple connections within a test always connect to the same
  # MLMD instance.
  metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
  self._metadata_path = metadata_path
  connection_config = metadata.sqlite_metadata_connection_config(
      metadata_path)
  connection_config.sqlite.SetInParent()
  self._mlmd_connection = metadata.Metadata(
      connection_config=connection_config)

  # Sets up the pipeline from the shared test-pipeline builder.
  pipeline = test_async_pipeline.create_pipeline()

  # Extracts components; node order follows the builder's definition.
  self._example_gen = pipeline.nodes[0].pipeline_node
  self._transform = pipeline.nodes[1].pipeline_node
  self._trainer = pipeline.nodes[2].pipeline_node

  # Pack deployment config for testing. 'fake.ClassPath' never resolves;
  # tests presumably intercept execution before the executor is loaded.
  deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
  executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
      class_path='fake.ClassPath')
  deployment_config.executor_specs[self._trainer.node_info.id].Pack(
      executor_spec)
  deployment_config.executor_specs[self._transform.node_info.id].Pack(
      executor_spec)
  self._type_url = deployment_config.executor_specs[
      self._trainer.node_info.id].type_url
  pipeline.deployment_config.Pack(deployment_config)
  self._pipeline = pipeline
  self._pipeline_info = pipeline.pipeline_info
  self._pipeline_runtime_spec = pipeline.runtime_spec
  self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
      pipeline_root)

  # Start from a clean scheduler registry and an empty task queue.
  ts.TaskSchedulerRegistry.clear()
  self._task_queue = tq.TaskQueue()

  # Run fake example-gen to prepare downstreams component triggers.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen,
                                  1, 1)

  # Task generator should produce two tasks for transform. The first one is
  # UpdateNodeStateTask and the second one is ExecNodeTask.
  with self._mlmd_connection as m:
    pipeline_state = pstate.PipelineState.new(m, self._pipeline)
    tasks = asptg.AsyncPipelineTaskGenerator(
        m, self._task_queue.contains_task_id,
        service_jobs.DummyServiceJobManager()).generate(pipeline_state)
  self.assertLen(tasks, 2)
  self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
  self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
  self.assertEqual('my_transform', tasks[0].node_uid.node_id)
  self.assertTrue(task_lib.is_exec_node_task(tasks[1]))
  self.assertEqual('my_transform', tasks[1].node_uid.node_id)
  # The ExecNodeTask's working directories and output artifact URI must have
  # been created on disk by the generator.
  self.assertTrue(os.path.exists(tasks[1].stateful_working_dir))
  self.assertTrue(os.path.exists(tasks[1].tmp_dir))
  self._task = tasks[1]
  self._output_artifact_uri = self._task.output_artifacts[
      'transform_graph'][0].uri
  self.assertTrue(os.path.exists(self._output_artifact_uri))
  self._task_queue.enqueue(self._task)

  # There should be 1 active execution in MLMD.
  with self._mlmd_connection as m:
    executions = m.store.get_executions()
  active_executions = [
      e for e in executions
      if e.last_known_state == metadata_store_pb2.Execution.RUNNING
  ]
  self.assertLen(active_executions, 1)

  # Active execution id.
  self._execution_id = active_executions[0].id