def _exec_node_tasks():
  for pipeline_id in ('pipeline1', 'pipeline2'):
    yield [
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid(
                    pipeline_id=pipeline_id, pipeline_run_id=None),
                node_id='Transform')),
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid(
                    pipeline_id=pipeline_id, pipeline_run_id=None),
                node_id='Trainer'))
    ]
def test_stop_pipeline_non_existent(self):
  with self._mlmd_connection as m:
    # Stop pipeline without creating one.
    with self.assertRaises(
        status_lib.StatusNotOkError) as exception_context:
      pipeline_ops.stop_pipeline(
          m, task_lib.PipelineUid(pipeline_id='foo', pipeline_run_id=None))
    self.assertEqual(status_lib.Code.NOT_FOUND,
                     exception_context.exception.code)

    # Initiate pipeline start, then initiate stop and mark the execution
    # completed.
    pipeline1 = _test_pipeline('pipeline1')
    execution = pipeline_ops.initiate_pipeline_start(m, pipeline1)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline1)
    pipeline_ops._initiate_pipeline_stop(m, pipeline_uid)
    execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
    m.store.put_executions([execution])

    # Try to initiate stop again.
    with self.assertRaises(
        status_lib.StatusNotOkError) as exception_context:
      pipeline_ops.stop_pipeline(m, pipeline_uid)
    self.assertEqual(status_lib.Code.ALREADY_EXISTS,
                     exception_context.exception.code)
def test_task_ids(self):
  pipeline_uid = task_lib.PipelineUid(pipeline_id='pipeline')
  node_uid = task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id='Trainer')
  exec_node_task = test_utils.create_exec_node_task(node_uid)
  self.assertEqual(('ExecNodeTask', node_uid), exec_node_task.task_id)
  cancel_node_task = task_lib.CancelNodeTask(node_uid=node_uid)
  self.assertEqual(('CancelNodeTask', node_uid), cancel_node_task.task_id)
def pipeline_uid_from_orchestrator_context(
    context: metadata_store_pb2.Context) -> task_lib.PipelineUid:
  """Returns pipeline uid from orchestrator reserved context."""
  splits = context.name.split(':')
  pipeline_id = splits[0]
  key = splits[1] if len(splits) > 1 else ''
  return task_lib.PipelineUid(pipeline_id=pipeline_id, key=key)
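# Illustrative examples only (assuming context names follow the
# 'pipeline_id:key' convention implied by the split above):
#   Context(name='pipeline1:run0') -> PipelineUid(pipeline_id='pipeline1', key='run0')
#   Context(name='pipeline1')      -> PipelineUid(pipeline_id='pipeline1', key='')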
def test_sync_pipeline_run_id_runtime_parameter(self):
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
    pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)
    self.assertNotEmpty(pipeline_state.pipeline.runtime_spec.pipeline_run_id
                        .field_value.string_value)
    self.assertEqual(
        task_lib.PipelineUid(pipeline_id='pipeline1'),
        pipeline_state.pipeline_uid)
def test_scheduler_not_found(self):
  task = test_utils.create_exec_node_task(
      node_uid=task_lib.NodeUid(
          pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
          node_id='Transform'),
      pipeline=self._pipeline)
  with self.assertRaisesRegex(ValueError, 'No task scheduler found'):
    ts.TaskSchedulerRegistry.create_task_scheduler(mock.Mock(),
                                                   self._pipeline, task)
def _test_exec_node_task(node_id, pipeline_id, pipeline_run_id=None,
                         pipeline=None):
  node_uid = task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(
          pipeline_id=pipeline_id, pipeline_run_id=pipeline_run_id),
      node_id=node_id)
  return test_utils.create_exec_node_task(node_uid, pipeline=pipeline)
def test_node_uid_from_pipeline_node(self):
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = 'pipeline'
  node = pipeline_pb2.PipelineNode()
  node.node_info.id = 'Trainer'
  self.assertEqual(
      task_lib.NodeUid(
          pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
          node_id='Trainer'),
      task_lib.NodeUid.from_pipeline_node(pipeline, node))
def test_node_uid_from_pipeline_node(self):
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = 'pipeline'
  pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
  node = pipeline_pb2.PipelineNode()
  node.node_info.id = 'Trainer'
  self.assertEqual(
      task_lib.NodeUid(
          pipeline_uid=task_lib.PipelineUid(
              pipeline_id='pipeline', pipeline_run_id='run0'),
          node_id='Trainer'),
      task_lib.NodeUid.from_pipeline_node(pipeline, node))
def test_register_using_executor_spec_type_url(self):
  # Register a fake task scheduler.
  ts.TaskSchedulerRegistry.register(self._spec_type_url, _FakeTaskScheduler)

  # Create a task and verify that the correct scheduler is instantiated.
  task = test_utils.create_exec_node_task(
      node_uid=task_lib.NodeUid(
          pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
          node_id='Trainer'),
      pipeline=self._pipeline)
  task_scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
      mock.Mock(), self._pipeline, task)
  self.assertIsInstance(task_scheduler, _FakeTaskScheduler)
def test_register_using_node_type_name(self):
  # Register a fake task scheduler.
  ts.TaskSchedulerRegistry.register(constants.IMPORTER_NODE_TYPE,
                                    _FakeTaskScheduler)

  # Create a task and verify that the correct scheduler is instantiated.
  task = test_utils.create_exec_node_task(
      node_uid=task_lib.NodeUid(
          pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
          node_id='Importer'),
      pipeline=self._pipeline)
  task_scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
      mock.Mock(), self._pipeline, task)
  self.assertIsInstance(task_scheduler, _FakeTaskScheduler)
def test_registration_and_creation(self):
  # Create a pipeline IR containing deployment config for testing.
  deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
  executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
      class_path='trainer.TrainerExecutor')
  deployment_config.executor_specs['Trainer'].Pack(executor_spec)
  pipeline = pipeline_pb2.Pipeline()
  pipeline.deployment_config.Pack(deployment_config)

  # Register a fake task scheduler.
  spec_type_url = deployment_config.executor_specs['Trainer'].type_url
  ts.TaskSchedulerRegistry.register(spec_type_url, _FakeTaskScheduler)

  # Create a task and verify that the correct scheduler is instantiated.
  task = test_utils.create_exec_node_task(
      node_uid=task_lib.NodeUid(
          pipeline_uid=task_lib.PipelineUid(
              pipeline_id='pipeline', pipeline_run_id=None),
          node_id='Trainer'))
  task_scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
      mock.Mock(), pipeline, task)
  self.assertIsInstance(task_scheduler, _FakeTaskScheduler)
def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                        mock_async_task_gen,
                                        mock_sync_task_gen):
  with self._mlmd_connection as m:
    pipeline1 = _test_pipeline('pipeline1')
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Evaluator'
    pipeline_ops.initiate_pipeline_start(m, pipeline1)
    pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
        m, task_lib.PipelineUid.from_pipeline(pipeline1))

    task_queue = tq.TaskQueue()

    # For the stop-initiated pipeline, "Transform" execution task is in queue,
    # "Trainer" has an active execution in MLMD but no task in queue,
    # "Evaluator" has no active execution.
    task_queue.enqueue(
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid(
                    pipeline_id='pipeline1', pipeline_run_id=None),
                node_id='Transform')))
    transform_task = task_queue.dequeue()  # Simulates task being processed.
    mock_gen_task_from_active.side_effect = [
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid(
                    pipeline_id='pipeline1', pipeline_run_id=None),
                node_id='Trainer'),
            is_cancelled=True), None, None, None, None
    ]

    pipeline_ops.generate_tasks(m, task_queue)

    # There are no active pipelines so these shouldn't be called.
    mock_async_task_gen.assert_not_called()
    mock_sync_task_gen.assert_not_called()

    # Simulate finishing the "Transform" ExecNodeTask.
    task_queue.task_done(transform_task)

    # CancelNodeTask for the "Transform" ExecNodeTask should be next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_cancel_node_task(task))
    self.assertEqual('Transform', task.node_uid.node_id)

    # ExecNodeTask for "Trainer" is next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual('Trainer', task.node_uid.node_id)

    self.assertTrue(task_queue.is_empty())

    mock_gen_task_from_active.assert_has_calls([
        mock.call(m, pipeline1, pipeline1.nodes[1].pipeline_node, mock.ANY,
                  is_cancelled=True),
        mock.call(m, pipeline1, pipeline1.nodes[2].pipeline_node, mock.ANY,
                  is_cancelled=True)
    ])
    self.assertEqual(2, mock_gen_task_from_active.call_count)

    # Pipeline execution should continue to be active since active node
    # executions were found in the last call to `generate_tasks`.
    [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
    self.assertTrue(execution_lib.is_execution_active(execution))

    # Call `generate_tasks` again; this time there are no more active node
    # executions so the pipeline should be marked as cancelled.
    pipeline_ops.generate_tasks(m, task_queue)
    self.assertTrue(task_queue.is_empty())
    [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
    self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                     execution.last_known_state)
def _pipeline_uid_from_context(
    context: metadata_store_pb2.Context) -> task_lib.PipelineUid:
  """Returns pipeline uid from orchestrator reserved context."""
  pipeline_id = context.name.split(_ORCHESTRATOR_RESERVED_ID + '_')[1]
  return task_lib.PipelineUid(pipeline_id=pipeline_id, pipeline_run_id=None)
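# Illustrative example only (assuming the context is named
# f'{_ORCHESTRATOR_RESERVED_ID}_<pipeline_id>', as the split above implies):
#   Context(name=_ORCHESTRATOR_RESERVED_ID + '_pipeline1')
#     -> PipelineUid(pipeline_id='pipeline1', pipeline_run_id=None)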
def pipeline_uid_from_orchestrator_context(
    context: metadata_store_pb2.Context) -> task_lib.PipelineUid:
  """Returns pipeline uid from orchestrator reserved context."""
  return task_lib.PipelineUid(context.name)
def create_node_uid(pipeline_id, node_id):
  """Creates node uid."""
  return task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(pipeline_id=pipeline_id),
      node_id=node_id)
def _test_cancel_node_task(node_id, pipeline_id):
  node_uid = task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(pipeline_id=pipeline_id),
      node_id=node_id)
  return task_lib.CancelNodeTask(node_uid=node_uid)
def _test_task(node_id, pipeline_id, key=''):
  node_uid = task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(pipeline_id=pipeline_id, key=key),
      node_id=node_id)
  return test_utils.create_exec_node_task(node_uid)