def test_stop_pipeline_non_existent_or_inactive(self, pipeline):
    """stop_pipeline raises NOT_FOUND for missing or inactive pipelines."""
    with self._mlmd_connection as m:
        # Stop pipeline without creating one.
        with self.assertRaises(
            status_lib.StatusNotOkError) as exception_context:
            pipeline_ops.stop_pipeline(
                m, task_lib.PipelineUid.from_pipeline(pipeline))
        self.assertEqual(status_lib.Code.NOT_FOUND,
                         exception_context.exception.code)

        # Initiate pipeline start and mark it completed.
        pipeline_ops.initiate_pipeline_start(m, pipeline)
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
            pipeline_state.initiate_stop(
                status_lib.Status(code=status_lib.Code.OK))
            pipeline_state.execution.last_known_state = (
                metadata_store_pb2.Execution.COMPLETE)

        # Try to initiate stop again; the pipeline is now inactive so the call
        # should again fail with NOT_FOUND.
        with self.assertRaises(
            status_lib.StatusNotOkError) as exception_context:
            pipeline_ops.stop_pipeline(m, pipeline_uid)
        self.assertEqual(status_lib.Code.NOT_FOUND,
                         exception_context.exception.code)
def test_initiate_pipeline_start(self, pipeline):
    """Tests starting pipelines and restarting after inactivation."""
    with self._mlmd_connection as m:
        # Initiate a pipeline start.
        pipeline_state1 = pipeline_ops.initiate_pipeline_start(m, pipeline)
        self.assertProtoPartiallyEquals(
            pipeline, pipeline_state1.pipeline, ignored_fields=['runtime_spec'])
        self.assertEqual(metadata_store_pb2.Execution.NEW,
                         pipeline_state1.execution.last_known_state)

        # Initiate another pipeline start.
        pipeline2 = _test_pipeline('pipeline2')
        pipeline_state2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)
        self.assertEqual(pipeline2, pipeline_state2.pipeline)
        self.assertEqual(metadata_store_pb2.Execution.NEW,
                         pipeline_state2.execution.last_known_state)

        # Error if attempted to initiate when old one is active.
        with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
            pipeline_ops.initiate_pipeline_start(m, pipeline)
        self.assertEqual(status_lib.Code.ALREADY_EXISTS,
                         exception_context.exception.code)

        # Fine to initiate after the previous one is inactive.
        execution = pipeline_state1.execution
        execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
        m.store.put_executions([execution])
        pipeline_state3 = pipeline_ops.initiate_pipeline_start(m, pipeline)
        self.assertEqual(metadata_store_pb2.Execution.NEW,
                         pipeline_state3.execution.last_known_state)
def test_initiate_pipeline_stop(self):
    """A stop request flips the pipeline state to stop-initiated."""
    with self._mlmd_connection as mlmd_handle:
        started_pipeline = _test_pipeline('pipeline1')
        pipeline_ops.initiate_pipeline_start(mlmd_handle, started_pipeline)
        uid = task_lib.PipelineUid.from_pipeline(started_pipeline)
        state = pipeline_ops._initiate_pipeline_stop(mlmd_handle, uid)
        self.assertTrue(state.is_stop_initiated())
def test_stop_pipeline_wait_for_inactivation_timeout(self, pipeline):
    """stop_pipeline raises DEADLINE_EXCEEDED when inactivation never comes."""
    with self._mlmd_connection as mlmd_handle:
        pipeline_ops.initiate_pipeline_start(mlmd_handle, pipeline)
        # Nothing ever inactivates the execution, so a short timeout expires.
        with self.assertRaisesRegex(
            status_lib.StatusNotOkError,
            'Timed out.*waiting for execution inactivation.') as ctx:
            pipeline_ops.stop_pipeline(
                mlmd_handle,
                task_lib.PipelineUid.from_pipeline(pipeline),
                timeout_secs=1.0)
        self.assertEqual(status_lib.Code.DEADLINE_EXCEEDED, ctx.exception.code)
def test_initiate_pipeline_stop(self):
    """Stop initiation is recorded as a custom property on the execution."""
    with self._mlmd_connection as mlmd_handle:
        started_pipeline = _test_pipeline('pipeline1')
        pipeline_ops.initiate_pipeline_start(mlmd_handle, started_pipeline)
        uid = task_lib.PipelineUid.from_pipeline(started_pipeline)
        pipeline_ops._initiate_pipeline_stop(mlmd_handle, uid)

        # Verify MLMD state.
        executions = mlmd_handle.store.get_executions_by_type(
            pipeline_ops._ORCHESTRATOR_RESERVED_ID)
        self.assertLen(executions, 1)
        [execution] = executions
        self.assertEqual(
            1,
            execution.custom_properties[pipeline_ops._STOP_INITIATED].int_value)
def create_sample_pipeline(m: metadata.Metadata,
                           pipeline_id: str,
                           run_num: int,
                           export_ir_path: str = '',
                           external_ir_file: str = '',
                           deployment_config: Optional[message.Message] = None,
                           execute_nodes_func: Callable[
                               [metadata.Metadata, pipeline_pb2.Pipeline, int],
                               None] = _execute_nodes):
    """Creates a list of pipeline and node execution."""
    ir_path = _get_ir_path(external_ir_file)
    for run_index in range(run_num):
        run_id = 'run%02d' % run_index
        pipeline = _test_pipeline(ir_path, pipeline_id, run_id,
                                  deployment_config)
        if export_ir_path:
            # Optionally dump the pipeline IR as pbtxt for inspection.
            output_path = os.path.join(
                export_ir_path, '%s_%s.pbtxt' % (pipeline_id, run_id))
            io_utils.write_pbtxt_file(output_path, pipeline)
        pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)
        if not external_ir_file:
            execute_nodes_func(m, pipeline, run_index)
        # All runs but the last are marked complete so only one stays active.
        if run_index < run_num - 1:
            with pipeline_state:
                pipeline_state.set_pipeline_execution_state(
                    metadata_store_pb2.Execution.COMPLETE)
def test_stop_pipeline_non_existent(self):
    """Tests stop_pipeline error codes for missing / already-stopped cases."""
    with self._mlmd_connection as m:
        # Stop pipeline without creating one.
        with self.assertRaises(
            status_lib.StatusNotOkError) as exception_context:
            pipeline_ops.stop_pipeline(
                m, task_lib.PipelineUid(pipeline_id='foo',
                                        pipeline_run_id=None))
        self.assertEqual(status_lib.Code.NOT_FOUND,
                         exception_context.exception.code)

        # Initiate pipeline start and mark it completed.
        pipeline1 = _test_pipeline('pipeline1')
        execution = pipeline_ops.initiate_pipeline_start(m, pipeline1)
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline1)
        pipeline_ops._initiate_pipeline_stop(m, pipeline_uid)
        execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
        m.store.put_executions([execution])

        # Try to initiate stop again; stop was already initiated above so this
        # fails with ALREADY_EXISTS.
        with self.assertRaises(
            status_lib.StatusNotOkError) as exception_context:
            pipeline_ops.stop_pipeline(m, pipeline_uid)
        self.assertEqual(status_lib.Code.ALREADY_EXISTS,
                         exception_context.exception.code)
def test_save_and_remove_pipeline_property(self):
    """Custom pipeline properties can be saved and later removed."""
    with self._mlmd_connection as m:
        pipeline1 = _test_pipeline('pipeline1')
        pipeline_state1 = pipeline_ops.initiate_pipeline_start(m, pipeline1)
        property_key = 'test_key'
        property_value = 'bala'
        # Property is absent before save.
        self.assertIsNone(
            pipeline_state1.execution.custom_properties.get(property_key))
        pipeline_ops.save_pipeline_property(pipeline_state1.mlmd_handle,
                                            pipeline_state1.pipeline_uid,
                                            property_key, property_value)

        # Reload state and verify the property was persisted.
        with pstate.PipelineState.load(
            m, pipeline_state1.pipeline_uid) as pipeline_state2:
            self.assertIsNotNone(
                pipeline_state2.execution.custom_properties.get(property_key))
            self.assertEqual(
                pipeline_state2.execution.custom_properties[property_key]
                .string_value, property_value)

        pipeline_ops.remove_pipeline_property(pipeline_state2.mlmd_handle,
                                              pipeline_state2.pipeline_uid,
                                              property_key)
        # Reload again and verify the property was removed.
        with pstate.PipelineState.load(
            m, pipeline_state2.pipeline_uid) as pipeline_state3:
            self.assertIsNone(
                pipeline_state3.execution.custom_properties.get(property_key))
def test_sync_pipeline_run_id_runtime_parameter(self):
    """A sync pipeline start resolves the pipeline_run_id runtime parameter."""
    with self._mlmd_connection as mlmd_handle:
        sync_pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
        state = pipeline_ops.initiate_pipeline_start(mlmd_handle, sync_pipeline)
        resolved_run_id = (
            state.pipeline.runtime_spec.pipeline_run_id.field_value.string_value)
        self.assertNotEmpty(resolved_run_id)
        # The pipeline uid is keyed by pipeline id only (no run id component).
        self.assertEqual(
            task_lib.PipelineUid(pipeline_id='pipeline1'), state.pipeline_uid)
def create_sample_pipeline(m: metadata.Metadata, pipeline_id: str,
                           run_num: int):
    """Creates a list of pipeline and node execution."""
    for run_index in range(run_num):
        run_id = 'run%02d' % run_index
        pipeline = _test_pipeline(pipeline_id, run_id)
        pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)
        _execute_nodes(m, pipeline, run_index)
        # Mark every run but the final one complete so only one stays active.
        if run_index < run_num - 1:
            execution = pipeline_state.execution
            execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
            m.store.put_executions([execution])
def test_initiate_pipeline_start(self):
    """Starting pipelines creates orchestrator contexts/executions in MLMD."""
    with self._mlmd_connection as m:
        # Initiate a pipeline start.
        pipeline1 = _test_pipeline('pipeline1')
        pipeline_ops.initiate_pipeline_start(m, pipeline1)

        # Initiate another pipeline start.
        pipeline2 = _test_pipeline('pipeline2')
        pipeline_ops.initiate_pipeline_start(m, pipeline2)

        # No error raised => context/execution types exist.
        m.store.get_context_type(pipeline_ops._ORCHESTRATOR_RESERVED_ID)
        m.store.get_execution_type(pipeline_ops._ORCHESTRATOR_RESERVED_ID)

        # Verify MLMD state: one orchestrator context per started pipeline.
        contexts = m.store.get_contexts_by_type(
            pipeline_ops._ORCHESTRATOR_RESERVED_ID)
        self.assertLen(contexts, 2)
        self.assertCountEqual([
            pipeline_ops._orchestrator_context_name(
                task_lib.PipelineUid.from_pipeline(pipeline1)),
            pipeline_ops._orchestrator_context_name(
                task_lib.PipelineUid.from_pipeline(pipeline2))
        ], [c.name for c in contexts])
        for context in contexts:
            # Each context holds a single NEW execution.
            executions = m.store.get_executions_by_context(context.id)
            self.assertLen(executions, 1)
            self.assertEqual(metadata_store_pb2.Execution.NEW,
                             executions[0].last_known_state)
            # The pipeline IR is stored base64-encoded in an execution property
            # and should round-trip back to the original pipeline proto.
            retrieved_pipeline = pipeline_pb2.Pipeline()
            retrieved_pipeline.ParseFromString(
                base64.b64decode(executions[0].properties[
                    pipeline_ops._PIPELINE_IR].string_value))
            expected_pipeline_id = (
                pipeline_ops._pipeline_uid_from_context(context).pipeline_id)
            self.assertEqual(_test_pipeline(expected_pipeline_id),
                             retrieved_pipeline)
def test_handling_finalize_node_task(self, task_gen):
    """A FinalizeNodeTask results in node stop initiation."""
    with self._mlmd_connection as m:
        pipeline = _test_pipeline('pipeline1')
        pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
        pipeline_ops.initiate_pipeline_start(m, pipeline)
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        finalize_reason = status_lib.Status(
            code=status_lib.Code.ABORTED, message='foo bar')
        # The mocked generator emits an ExecNodeTask for Transform and a
        # FinalizeNodeTask for Trainer.
        task_gen.return_value.generate.side_effect = [
            [
                test_utils.create_exec_node_task(
                    task_lib.NodeUid(
                        pipeline_uid=pipeline_uid, node_id='Transform')),
                task_lib.FinalizeNodeTask(
                    node_uid=task_lib.NodeUid(
                        pipeline_uid=pipeline_uid, node_id='Trainer'),
                    status=finalize_reason)
            ],
        ]

        task_queue = tq.TaskQueue()
        pipeline_ops.orchestrate(m, task_queue,
                                 service_jobs.DummyServiceJobManager())
        task_gen.return_value.generate.assert_called_once()
        # The ExecNodeTask for Transform should be in the queue.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_exec_node_task(task))
        self.assertEqual(
            test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)

        # Load pipeline state and verify node stop initiation.
        with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
            self.assertEqual(
                finalize_reason,
                pipeline_state.node_stop_initiated_reason(
                    task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                     node_id='Trainer')))
def test_initiate_pipeline_start_new_execution(self):
    """Restarting an inactive pipeline creates a new MLMD execution."""
    with self._mlmd_connection as m:
        pipeline1 = _test_pipeline('pipeline1')
        pipeline_ops.initiate_pipeline_start(m, pipeline1)

        # Error if attempted to initiate when old one is active.
        with self.assertRaises(
            status_lib.StatusNotOkError) as exception_context:
            pipeline_ops.initiate_pipeline_start(m, pipeline1)
        self.assertEqual(status_lib.Code.ALREADY_EXISTS,
                         exception_context.exception.code)

        # Fine to initiate after the previous one is inactive.
        executions = m.store.get_executions_by_type(
            pipeline_ops._ORCHESTRATOR_RESERVED_ID)
        self.assertLen(executions, 1)
        executions[
            0].last_known_state = metadata_store_pb2.Execution.COMPLETE
        m.store.put_executions(executions)
        execution = pipeline_ops.initiate_pipeline_start(m, pipeline1)
        self.assertEqual(metadata_store_pb2.Execution.NEW,
                         execution.last_known_state)

        # Verify MLMD state: a single context now holds both executions.
        contexts = m.store.get_contexts_by_type(
            pipeline_ops._ORCHESTRATOR_RESERVED_ID)
        self.assertLen(contexts, 1)
        self.assertEqual(
            pipeline_ops._orchestrator_context_name(
                task_lib.PipelineUid.from_pipeline(pipeline1)),
            contexts[0].name)
        executions = m.store.get_executions_by_context(contexts[0].id)
        self.assertLen(executions, 2)
        self.assertCountEqual([
            metadata_store_pb2.Execution.COMPLETE,
            metadata_store_pb2.Execution.NEW
        ], [e.last_known_state for e in executions])
def test_handling_finalize_pipeline_task(self, task_gen):
    """A FinalizePipelineTask results in pipeline stop initiation."""
    with self._mlmd_connection as m:
        pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
        pipeline_ops.initiate_pipeline_start(m, pipeline)
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        finalize_reason = status_lib.Status(
            code=status_lib.Code.ABORTED, message='foo bar')
        task_gen.return_value.generate.side_effect = [
            [
                task_lib.FinalizePipelineTask(
                    pipeline_uid=pipeline_uid, status=finalize_reason)
            ],
        ]

        task_queue = tq.TaskQueue()
        pipeline_ops.orchestrate(m, task_queue,
                                 service_jobs.DummyServiceJobManager())
        task_gen.return_value.generate.assert_called_once()
        # The FinalizePipelineTask is consumed by orchestration, not enqueued.
        self.assertTrue(task_queue.is_empty())

        # Load pipeline state and verify stop initiation.
        with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
            self.assertEqual(finalize_reason,
                             pipeline_state.stop_initiated_reason())
def create_sample_pipeline(m: metadata.Metadata,
                           pipeline_id: str,
                           run_num: int,
                           export_ir_path: str = ''):
    """Creates a list of pipeline and node execution."""
    for run_index in range(run_num):
        run_id = 'run%02d' % run_index
        pipeline = _test_pipeline(pipeline_id, run_id)
        if export_ir_path:
            # Optionally dump the pipeline IR as pbtxt for inspection.
            output_path = os.path.join(export_ir_path,
                                       '%s_%s.pbtxt' % (pipeline_id, run_id))
            io_utils.write_pbtxt_file(output_path, pipeline)
        pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)
        _execute_nodes(m, pipeline, run_index)
        # Every run except the last is marked complete so only one is active.
        if run_index < run_num - 1:
            with pipeline_state:
                pipeline_state.execution.last_known_state = (
                    metadata_store_pb2.Execution.COMPLETE)
def test_stop_pipeline_wait_for_inactivation(self, pipeline):
    """stop_pipeline returns once the execution becomes inactive."""
    with self._mlmd_connection as m:
        execution = pipeline_ops.initiate_pipeline_start(m, pipeline).execution

        def _inactivate(execution):
            # Let stop_pipeline begin waiting before flipping the state.
            time.sleep(2.0)
            with pipeline_ops._PIPELINE_OPS_LOCK:
                execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
                m.store.put_executions([execution])

        inactivation_thread = threading.Thread(
            target=_inactivate, args=(copy.deepcopy(execution),))
        inactivation_thread.start()
        # Should return well before the timeout since the thread completes the
        # execution after ~2 seconds.
        pipeline_ops.stop_pipeline(
            m, task_lib.PipelineUid.from_pipeline(pipeline), timeout_secs=10.0)
        inactivation_thread.join()
def test_stop_pipeline_wait_for_inactivation(self, pipeline):
    """stop_pipeline returns once the pipeline execution becomes inactive."""
    with self._mlmd_connection as m:
        pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)

        def _inactivate(pipeline_state):
            # Let stop_pipeline begin polling before completing the pipeline.
            time.sleep(2.0)
            with pipeline_ops._PIPELINE_OPS_LOCK:
                with pipeline_state:
                    pipeline_state.set_pipeline_execution_state(
                        metadata_store_pb2.Execution.COMPLETE)

        inactivation_thread = threading.Thread(
            target=_inactivate, args=(pipeline_state,))
        inactivation_thread.start()
        # Should return well before the 10s timeout since the thread completes
        # the pipeline after ~2 seconds.
        pipeline_ops.stop_pipeline(
            m, task_lib.PipelineUid.from_pipeline(pipeline), timeout_secs=10.0)
        inactivation_thread.join()
def test_active_pipelines_with_stop_initiated_nodes(self,
                                                    mock_gen_task_from_active,
                                                    mock_async_task_gen):
    """Orchestrating an active pipeline whose nodes are stop-initiated."""
    with self._mlmd_connection as m:
        pipeline = _test_pipeline('pipeline')
        pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

        mock_service_job_manager = mock.create_autospec(
            service_jobs.ServiceJobManager, instance=True)
        # Only ExampleGen is treated as a pure service node.
        mock_service_job_manager.is_pure_service_node.side_effect = (
            lambda _, node_id: node_id == 'ExampleGen')

        example_gen_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[0].pipeline_node)
        transform_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[1].pipeline_node)
        transform_task = test_utils.create_exec_node_task(
            node_uid=transform_node_uid)
        trainer_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[2].pipeline_node)
        trainer_task = test_utils.create_exec_node_task(
            node_uid=trainer_node_uid)
        evaluator_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[3].pipeline_node)
        evaluator_task = test_utils.create_exec_node_task(
            node_uid=evaluator_node_uid)
        cancelled_evaluator_task = test_utils.create_exec_node_task(
            node_uid=evaluator_node_uid, is_cancelled=True)

        pipeline_ops.initiate_pipeline_start(m, pipeline)
        with pstate.PipelineState.load(
            m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
            # Stop example-gen, trainer and evaluator.
            pipeline_state.initiate_node_stop(
                example_gen_node_uid,
                status_lib.Status(code=status_lib.Code.CANCELLED))
            pipeline_state.initiate_node_stop(
                trainer_node_uid,
                status_lib.Status(code=status_lib.Code.CANCELLED))
            pipeline_state.initiate_node_stop(
                evaluator_node_uid,
                status_lib.Status(code=status_lib.Code.ABORTED))

        task_queue = tq.TaskQueue()

        # Simulate a new transform execution being triggered.
        mock_async_task_gen.return_value.generate.return_value = [
            transform_task
        ]
        # Simulate ExecNodeTask for trainer already present in the task queue.
        task_queue.enqueue(trainer_task)
        # Simulate Evaluator having an active execution in MLMD. Since the
        # node is stop-initiated, the generated task carries is_cancelled=True.
        # Fix: previously the non-cancelled `evaluator_task` was returned here,
        # which the final (intended) equality assertion could not match.
        mock_gen_task_from_active.side_effect = [cancelled_evaluator_task]

        pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
        self.assertEqual(1,
                         mock_async_task_gen.return_value.generate.call_count)

        # stop_node_services should be called on example-gen which is a pure
        # service node.
        mock_service_job_manager.stop_node_services.assert_called_once_with(
            mock.ANY, 'ExampleGen')

        # Verify that tasks are enqueued in the expected order:

        # Pre-existing trainer task.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertEqual(trainer_task, task)

        # CancelNodeTask for trainer.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_cancel_node_task(task))
        self.assertEqual(trainer_node_uid, task.node_uid)

        # ExecNodeTask with is_cancelled=True for evaluator.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        # Fix: `assertTrue(a, b)` treats the second argument as the failure
        # message and compares nothing; assertEqual performs the intended check.
        self.assertEqual(cancelled_evaluator_task, task)

        # ExecNodeTask for newly triggered transform node.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertEqual(transform_task, task)

        # No more tasks.
        self.assertTrue(task_queue.is_empty())
def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
                                  mock_async_task_gen, mock_sync_task_gen):
    """Stop-initiated pipeline: nodes are cancelled, services are stopped."""
    with self._mlmd_connection as m:
        pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

        mock_service_job_manager = mock.create_autospec(
            service_jobs.ServiceJobManager, instance=True)
        # ExampleGen is a pure service node; Transform is a mixed service node.
        mock_service_job_manager.is_pure_service_node.side_effect = (
            lambda _, node_id: node_id == 'ExampleGen')
        mock_service_job_manager.is_mixed_service_node.side_effect = (
            lambda _, node_id: node_id == 'Transform')

        pipeline_ops.initiate_pipeline_start(m, pipeline)
        with pstate.PipelineState.load(
            m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
            pipeline_state.initiate_stop(
                status_lib.Status(code=status_lib.Code.CANCELLED))
            pipeline_execution = pipeline_state.execution

        task_queue = tq.TaskQueue()

        # For the stop-initiated pipeline, "Transform" execution task is in queue,
        # "Trainer" has an active execution in MLMD but no task in queue,
        # "Evaluator" has no active execution.
        task_queue.enqueue(
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                    node_id='Transform')))
        transform_task = task_queue.dequeue()  # simulates task being processed
        mock_gen_task_from_active.side_effect = [
            test_utils.create_exec_node_task(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                    node_id='Trainer'),
                is_cancelled=True), None, None, None, None
        ]

        pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

        # There are no active pipelines so these shouldn't be called.
        mock_async_task_gen.assert_not_called()
        mock_sync_task_gen.assert_not_called()

        # stop_node_services should be called for ExampleGen which is a pure
        # service node.
        mock_service_job_manager.stop_node_services.assert_called_once_with(
            mock.ANY, 'ExampleGen')
        mock_service_job_manager.reset_mock()

        task_queue.task_done(transform_task)  # Pop out transform task.

        # CancelNodeTask for the "Transform" ExecNodeTask should be next.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_cancel_node_task(task))
        self.assertEqual('Transform', task.node_uid.node_id)

        # ExecNodeTask (with is_cancelled=True) for "Trainer" is next.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_exec_node_task(task))
        self.assertEqual('Trainer', task.node_uid.node_id)
        self.assertTrue(task.is_cancelled)

        self.assertTrue(task_queue.is_empty())

        # Trainer and Evaluator are queried for active executions, cancelled.
        mock_gen_task_from_active.assert_has_calls([
            mock.call(
                m,
                pipeline_state.pipeline,
                pipeline.nodes[2].pipeline_node,
                mock.ANY,
                is_cancelled=True),
            mock.call(
                m,
                pipeline_state.pipeline,
                pipeline.nodes[3].pipeline_node,
                mock.ANY,
                is_cancelled=True)
        ])
        self.assertEqual(2, mock_gen_task_from_active.call_count)

        # Pipeline execution should continue to be active since active node
        # executions were found in the last call to `orchestrate`.
        [execution] = m.store.get_executions_by_id([pipeline_execution.id])
        self.assertTrue(execution_lib.is_execution_active(execution))

        # Call `orchestrate` again; this time there are no more active node
        # executions so the pipeline should be marked as cancelled.
        pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
        self.assertTrue(task_queue.is_empty())
        [execution] = m.store.get_executions_by_id([pipeline_execution.id])
        self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                         execution.last_known_state)

        # stop_node_services should be called on both ExampleGen and Transform
        # which are service nodes.
        mock_service_job_manager.stop_node_services.assert_has_calls(
            [mock.call(mock.ANY, 'ExampleGen'),
             mock.call(mock.ANY, 'Transform')],
            any_order=True)
def test_orchestrate_active_pipelines(self, mock_async_task_gen,
                                      mock_sync_task_gen):
    """Tasks from active sync and async pipelines are enqueued in order."""
    with self._mlmd_connection as m:
        # Sync and async active pipelines.
        async_pipelines = [
            _test_pipeline('pipeline1'),
            _test_pipeline('pipeline2'),
        ]
        sync_pipelines = [
            _test_pipeline('pipeline3', pipeline_pb2.Pipeline.SYNC),
            _test_pipeline('pipeline4', pipeline_pb2.Pipeline.SYNC),
        ]

        for pipeline in async_pipelines + sync_pipelines:
            pipeline_ops.initiate_pipeline_start(m, pipeline)

        # Active executions for active async pipelines.
        mock_async_task_gen.return_value.generate.side_effect = [
            [
                test_utils.create_exec_node_task(
                    task_lib.NodeUid(
                        pipeline_uid=task_lib.PipelineUid.from_pipeline(
                            async_pipelines[0]),
                        node_id='Transform'))
            ],
            [
                test_utils.create_exec_node_task(
                    task_lib.NodeUid(
                        pipeline_uid=task_lib.PipelineUid.from_pipeline(
                            async_pipelines[1]),
                        node_id='Trainer'))
            ],
        ]

        # Active executions for active sync pipelines.
        mock_sync_task_gen.return_value.generate.side_effect = [
            [
                test_utils.create_exec_node_task(
                    task_lib.NodeUid(
                        pipeline_uid=task_lib.PipelineUid.from_pipeline(
                            sync_pipelines[0]),
                        node_id='Trainer'))
            ],
            [
                test_utils.create_exec_node_task(
                    task_lib.NodeUid(
                        pipeline_uid=task_lib.PipelineUid.from_pipeline(
                            sync_pipelines[1]),
                        node_id='Validator'))
            ],
        ]

        task_queue = tq.TaskQueue()
        pipeline_ops.orchestrate(m, task_queue,
                                 service_jobs.DummyServiceJobManager())

        # One generate call per active pipeline of each kind.
        self.assertEqual(2,
                         mock_async_task_gen.return_value.generate.call_count)
        self.assertEqual(2,
                         mock_sync_task_gen.return_value.generate.call_count)

        # Verify that tasks are enqueued in the expected order.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_exec_node_task(task))
        self.assertEqual(
            test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_exec_node_task(task))
        self.assertEqual(
            test_utils.create_node_uid('pipeline2', 'Trainer'), task.node_uid)
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_exec_node_task(task))
        self.assertEqual(
            test_utils.create_node_uid('pipeline3', 'Trainer'), task.node_uid)
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_exec_node_task(task))
        self.assertEqual(
            test_utils.create_node_uid('pipeline4', 'Validator'), task.node_uid)
        self.assertTrue(task_queue.is_empty())
def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                        mock_async_task_gen,
                                        mock_sync_task_gen):
    """Stop-initiated async pipeline is cancelled across task generations."""
    with self._mlmd_connection as m:
        pipeline1 = _test_pipeline('pipeline1')
        pipeline1.nodes.add().pipeline_node.node_info.id = 'Transform'
        pipeline1.nodes.add().pipeline_node.node_info.id = 'Trainer'
        pipeline1.nodes.add().pipeline_node.node_info.id = 'Evaluator'
        pipeline_ops.initiate_pipeline_start(m, pipeline1)
        pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
            m, task_lib.PipelineUid.from_pipeline(pipeline1))

        task_queue = tq.TaskQueue()

        # For the stop-initiated pipeline, "Transform" execution task is in queue,
        # "Trainer" has an active execution in MLMD but no task in queue,
        # "Evaluator" has no active execution.
        task_queue.enqueue(
            test_utils.create_exec_node_task(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(
                        pipeline_id='pipeline1', pipeline_run_id=None),
                    node_id='Transform')))
        transform_task = task_queue.dequeue(
        )  # simulates task being processed
        mock_gen_task_from_active.side_effect = [
            test_utils.create_exec_node_task(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(
                        pipeline_id='pipeline1', pipeline_run_id=None),
                    node_id='Trainer'),
                is_cancelled=True), None, None, None, None
        ]

        pipeline_ops.generate_tasks(m, task_queue)

        # There are no active pipelines so these shouldn't be called.
        mock_async_task_gen.assert_not_called()
        mock_sync_task_gen.assert_not_called()

        # Simulate finishing the "Transform" ExecNodeTask.
        task_queue.task_done(transform_task)

        # CancelNodeTask for the "Transform" ExecNodeTask should be next.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_cancel_node_task(task))
        self.assertEqual('Transform', task.node_uid.node_id)

        # ExecNodeTask for "Trainer" is next.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_exec_node_task(task))
        self.assertEqual('Trainer', task.node_uid.node_id)

        self.assertTrue(task_queue.is_empty())

        # Trainer and Evaluator active executions are queried for cancellation.
        mock_gen_task_from_active.assert_has_calls([
            mock.call(m, pipeline1, pipeline1.nodes[1].pipeline_node,
                      mock.ANY, is_cancelled=True),
            mock.call(m, pipeline1, pipeline1.nodes[2].pipeline_node,
                      mock.ANY, is_cancelled=True)
        ])
        self.assertEqual(2, mock_gen_task_from_active.call_count)

        # Pipeline execution should continue to be active since active node
        # executions were found in the last call to `generate_tasks`.
        [execution
        ] = m.store.get_executions_by_id([pipeline1_execution.id])
        self.assertTrue(execution_lib.is_execution_active(execution))

        # Call `generate_tasks` again; this time there are no more active node
        # executions so the pipeline should be marked as cancelled.
        pipeline_ops.generate_tasks(m, task_queue)
        self.assertTrue(task_queue.is_empty())
        [execution
        ] = m.store.get_executions_by_id([pipeline1_execution.id])
        self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                         execution.last_known_state)
def test_generate_tasks_async_active_pipelines(self, mock_async_task_gen,
                                               mock_sync_task_gen):
    """Only active async pipelines are consulted for task generation."""
    with self._mlmd_connection as m:
        # One active pipeline.
        pipeline1 = _test_pipeline('pipeline1')
        pipeline_ops.initiate_pipeline_start(m, pipeline1)

        # Another active pipeline (with previously completed execution).
        pipeline2 = _test_pipeline('pipeline2')
        execution2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)
        execution2.last_known_state = metadata_store_pb2.Execution.COMPLETE
        m.store.put_executions([execution2])
        execution2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)

        # Inactive pipelines should be ignored.
        pipeline3 = _test_pipeline('pipeline3')
        execution3 = pipeline_ops.initiate_pipeline_start(m, pipeline3)
        execution3.last_known_state = metadata_store_pb2.Execution.COMPLETE
        m.store.put_executions([execution3])

        # For active pipelines pipeline1 and pipeline2, there are a couple of
        # active executions.
        def _exec_node_tasks():
            for pipeline_id in ('pipeline1', 'pipeline2'):
                yield [
                    test_utils.create_exec_node_task(
                        node_uid=task_lib.NodeUid(
                            pipeline_uid=task_lib.PipelineUid(
                                pipeline_id=pipeline_id, pipeline_run_id=None),
                            node_id='Transform')),
                    test_utils.create_exec_node_task(
                        node_uid=task_lib.NodeUid(
                            pipeline_uid=task_lib.PipelineUid(
                                pipeline_id=pipeline_id, pipeline_run_id=None),
                            node_id='Trainer'))
                ]

        mock_async_task_gen.return_value.generate.side_effect = _exec_node_tasks(
        )

        task_queue = tq.TaskQueue()
        pipeline_ops.generate_tasks(m, task_queue)

        # One generate call per active pipeline; sync generator is unused.
        self.assertEqual(
            2, mock_async_task_gen.return_value.generate.call_count)
        mock_sync_task_gen.assert_not_called()

        # Verify that tasks are enqueued in the expected order.
        for node_id in ('Transform', 'Trainer'):
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual(node_id, task.node_uid.node_id)
            self.assertEqual('pipeline1', task.node_uid.pipeline_uid.pipeline_id)
        for node_id in ('Transform', 'Trainer'):
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual(node_id, task.node_uid.node_id)
            self.assertEqual('pipeline2', task.node_uid.pipeline_uid.pipeline_id)
        self.assertTrue(task_queue.is_empty())
def test_active_pipelines_with_stop_initiated_nodes(
        self, mock_gen_task_from_active, mock_async_task_gen):
    """Orchestrating an active pipeline with stop-initiated nodes."""
    with self._mlmd_connection as m:
        pipeline = _test_pipeline('pipeline')
        pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

        transform_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[0].pipeline_node)
        transform_task = test_utils.create_exec_node_task(
            node_uid=transform_node_uid)
        trainer_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[1].pipeline_node)
        trainer_task = test_utils.create_exec_node_task(
            node_uid=trainer_node_uid)
        evaluator_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[2].pipeline_node)
        evaluator_task = test_utils.create_exec_node_task(
            node_uid=evaluator_node_uid)
        cancelled_evaluator_task = test_utils.create_exec_node_task(
            node_uid=evaluator_node_uid, is_cancelled=True)

        pipeline_ops.initiate_pipeline_start(m, pipeline)
        with pstate.PipelineState.load(
            m, task_lib.PipelineUid.from_pipeline(
                pipeline)) as pipeline_state:
            # Stop trainer and evaluator.
            pipeline_state.initiate_node_stop(trainer_node_uid)
            pipeline_state.initiate_node_stop(evaluator_node_uid)

        task_queue = tq.TaskQueue()

        # Simulate a new transform execution being triggered.
        mock_async_task_gen.return_value.generate.return_value = [
            transform_task
        ]
        # Simulate ExecNodeTask for trainer already present in the task queue.
        task_queue.enqueue(trainer_task)
        # Simulate Evaluator having an active execution in MLMD. Since the
        # node is stop-initiated, the generated task carries is_cancelled=True.
        # Fix: previously the non-cancelled `evaluator_task` was returned here,
        # which the final (intended) equality assertion could not match.
        mock_gen_task_from_active.side_effect = [cancelled_evaluator_task]

        pipeline_ops.orchestrate(m, task_queue)
        self.assertEqual(
            1, mock_async_task_gen.return_value.generate.call_count)

        # Verify that tasks are enqueued in the expected order:

        # Pre-existing trainer task.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertEqual(trainer_task, task)

        # CancelNodeTask for trainer.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_cancel_node_task(task))
        self.assertEqual(trainer_node_uid, task.node_uid)

        # ExecNodeTask with is_cancelled=True for evaluator.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        # Fix: `assertTrue(a, b)` treats the second argument as the failure
        # message and compares nothing; assertEqual performs the intended check.
        self.assertEqual(cancelled_evaluator_task, task)

        # ExecNodeTask for newly triggered transform node.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertEqual(transform_task, task)

        # No more tasks.
        self.assertTrue(task_queue.is_empty())