def test_node_state_for_skipped_nodes_in_partial_pipeline_run(self):
  """Tests that nodes marked to be skipped in a partial pipeline run have the right node state."""
  with self._mlmd_connection as m:
    # Build a 3-node SYNC pipeline IR in-memory.
    pipeline = pipeline_pb2.Pipeline()
    pipeline.pipeline_info.id = 'pipeline1'
    pipeline.execution_mode = pipeline_pb2.Pipeline.SYNC
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    # Mark ExampleGen and Transform to be skipped.
    pipeline.nodes[0].pipeline_node.execution_options.skip.SetInParent()
    pipeline.nodes[1].pipeline_node.execution_options.skip.SetInParent()
    eg_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
    transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform')
    trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')
    pstate.PipelineState.new(m, pipeline)
    # Reload from MLMD: skipped nodes should be COMPLETE, the remaining
    # (non-skipped) node should be STARTED.
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      self.assertEqual(
          {
              eg_node_uid:
                  pstate.NodeState(state=pstate.NodeState.COMPLETE),
              transform_node_uid:
                  pstate.NodeState(state=pstate.NodeState.COMPLETE),
              trainer_node_uid:
                  pstate.NodeState(state=pstate.NodeState.STARTED),
          }, pipeline_state.get_node_states_dict())
def test_pipeline_view_get_node_run_states(self):
  """Tests that PipelineView maps internal node states to run states."""
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Pusher'
    eg_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
    transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform')
    trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')
    evaluator_node_uid = task_lib.NodeUid(pipeline_uid, 'Evaluator')
    # Drive each node to a distinct internal state; 'Pusher' is left
    # untouched on purpose to exercise the default.
    with pstate.PipelineState.new(m, pipeline) as pipeline_state:
      with pipeline_state.node_state_update_context(
          eg_node_uid) as node_state:
        node_state.update(pstate.NodeState.RUNNING)
      with pipeline_state.node_state_update_context(
          transform_node_uid) as node_state:
        node_state.update(pstate.NodeState.STARTING)
      with pipeline_state.node_state_update_context(
          trainer_node_uid) as node_state:
        node_state.update(pstate.NodeState.STARTED)
      with pipeline_state.node_state_update_context(
          evaluator_node_uid) as node_state:
        node_state.update(
            pstate.NodeState.FAILED,
            status_lib.Status(
                code=status_lib.Code.ABORTED, message='foobar error'))
    [view] = pstate.PipelineView.load_all(
        m, task_lib.PipelineUid.from_pipeline(pipeline))
    run_states_dict = view.get_node_run_states()
    # RUNNING -> RUNNING, STARTING -> UNKNOWN, STARTED -> READY,
    # FAILED carries the status code and message through.
    self.assertEqual(
        run_state_pb2.RunState(state=run_state_pb2.RunState.RUNNING),
        run_states_dict['ExampleGen'])
    self.assertEqual(
        run_state_pb2.RunState(state=run_state_pb2.RunState.UNKNOWN),
        run_states_dict['Transform'])
    self.assertEqual(
        run_state_pb2.RunState(state=run_state_pb2.RunState.READY),
        run_states_dict['Trainer'])
    self.assertEqual(
        run_state_pb2.RunState(
            state=run_state_pb2.RunState.FAILED,
            status_code=run_state_pb2.RunState.StatusCodeValue(
                value=status_lib.Code.ABORTED),
            status_msg='foobar error'), run_states_dict['Evaluator'])
    self.assertEqual(
        run_state_pb2.RunState(state=run_state_pb2.RunState.READY),
        run_states_dict['Pusher'])
def test_stop_node_no_active_executions(self):
  """Tests that stopping an inactive node still records stop-initiation."""
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(
          os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
      pipeline)
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)
  with self._mlmd_connection as m:
    pstate.PipelineState.new(m, pipeline)
    pipeline_ops.stop_node(m, node_uid)
    pipeline_state = pstate.PipelineState.load(m, pipeline_uid)
    # The node should be stop-initiated even when node is inactive to prevent
    # future triggers.
    with pipeline_state:
      self.assertEqual(
          status_lib.Code.CANCELLED,
          pipeline_state.node_stop_initiated_reason(node_uid).code)
    # Restart node.
    pipeline_state = pipeline_ops.initiate_node_start(m, node_uid)
    with pipeline_state:
      # Restarting clears the stop-initiation reason.
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
def test_stop_node_wait_for_inactivation(self):
  """Tests stop_node blocking until the node's active execution completes.

  A background thread marks the execution COMPLETE after a delay; stop_node
  must observe the inactivation within its timeout and return normally.
  """
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(
          os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
      pipeline)
  trainer = pipeline.nodes[2].pipeline_node
  # Seed MLMD with an active execution for the trainer node.
  test_utils.fake_component_output(
      self._mlmd_connection, trainer, active=True)
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)
  with self._mlmd_connection as m:
    pstate.PipelineState.new(m, pipeline).commit()

    def _inactivate(execution):
      # Sleep shorter than the stop_node timeout (5s) so inactivation is
      # observed before the deadline.
      time.sleep(2.0)
      with pipeline_ops._PIPELINE_OPS_LOCK:
        execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
        m.store.put_executions([execution])

    execution = task_gen_utils.get_executions(m, trainer)[0]
    thread = threading.Thread(
        target=_inactivate, args=(copy.deepcopy(execution),))
    thread.start()
    pipeline_ops.stop_node(m, node_uid, timeout_secs=5.0)
    thread.join()
    pipeline_state = pstate.PipelineState.load(m, pipeline_uid)
    self.assertEqual(
        status_lib.Code.CANCELLED,
        pipeline_state.node_stop_initiated_reason(node_uid).code)
    # Restart node.
    pipeline_state = pipeline_ops.initiate_node_start(m, node_uid)
    self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
def test_task_ids(self):
  """Verifies task_id is (task type name, node uid) for both task kinds."""
  uid = task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
      node_id='Trainer')
  exec_task = test_utils.create_exec_node_task(uid)
  cancel_task = task_lib.CancelNodeTask(node_uid=uid)
  self.assertEqual(('ExecNodeTask', uid), exec_task.task_id)
  self.assertEqual(('CancelNodeTask', uid), cancel_task.task_id)
def test_stop_node_wait_for_inactivation_timeout(self):
  """Tests stop_node raising DEADLINE_EXCEEDED when the node stays active."""
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(
          os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
      pipeline)
  trainer = pipeline.nodes[2].pipeline_node
  # Seed an active execution that nothing will ever inactivate.
  test_utils.fake_component_output(
      self._mlmd_connection, trainer, active=True)
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)
  with self._mlmd_connection as m:
    pstate.PipelineState.new(m, pipeline).commit()
    with self.assertRaisesRegex(
        status_lib.StatusNotOkError,
        'Timed out.*waiting for execution inactivation.'
    ) as exception_context:
      pipeline_ops.stop_node(m, node_uid, timeout_secs=1.0)
    self.assertEqual(status_lib.Code.DEADLINE_EXCEEDED,
                     exception_context.exception.code)
    # Even if `wait_for_inactivation` times out, the node should be stop
    # initiated to prevent future triggers.
    pipeline_state = pstate.PipelineState.load(m, pipeline_uid)
    self.assertEqual(
        status_lib.Code.CANCELLED,
        pipeline_state.node_stop_initiated_reason(node_uid).code)
def test_scheduler_not_found(self):
  """Verifies creating a scheduler for an unregistered node raises ValueError."""
  node_uid = task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
      node_id='Transform')
  task = test_utils.create_exec_node_task(
      node_uid=node_uid, pipeline=self._pipeline)
  with self.assertRaisesRegex(ValueError, 'No task scheduler found'):
    ts.TaskSchedulerRegistry.create_task_scheduler(mock.Mock(),
                                                   self._pipeline, task)
def _test_exec_node_task(node_id, pipeline_id, pipeline_run_id=None,
                         pipeline=None):
  """Builds an ExecNodeTask for the given node in tests."""
  pipeline_uid = task_lib.PipelineUid(
      pipeline_id=pipeline_id, pipeline_run_id=pipeline_run_id)
  uid = task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id=node_id)
  return test_utils.create_exec_node_task(uid, pipeline=pipeline)
def test_node_uid_from_pipeline_node(self):
  """Verifies NodeUid.from_pipeline_node builds the expected uid."""
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = 'pipeline'
  node = pipeline_pb2.PipelineNode()
  node.node_info.id = 'Trainer'
  expected = task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
      node_id='Trainer')
  actual = task_lib.NodeUid.from_pipeline_node(pipeline, node)
  self.assertEqual(expected, actual)
def test_task_ids(self):
  """Tests that task_id is (task type name, node uid) for each task kind."""
  # NOTE(review): NodeUid is constructed here with pipeline_id /
  # pipeline_run_id kwargs directly (no PipelineUid), unlike other tests in
  # this file — confirm this matches the task_lib version in use.
  node_uid = task_lib.NodeUid(
      pipeline_id='pipeline', pipeline_run_id='run0', node_id='Trainer')
  exec_node_task = task_lib.ExecNodeTask(node_uid=node_uid, execution_id=123)
  self.assertEqual(('ExecNodeTask', node_uid), exec_node_task.task_id)
  cancel_node_task = task_lib.CancelNodeTask(node_uid=node_uid)
  self.assertEqual(('CancelNodeTask', node_uid), cancel_node_task.task_id)
def test_task_type_ids(self):
  """Tests task_type_id on both the classes and their instances."""
  self.assertEqual('ExecNodeTask', task_lib.ExecNodeTask.task_type_id())
  self.assertEqual('CancelNodeTask', task_lib.CancelNodeTask.task_type_id())
  # NOTE(review): NodeUid built with pipeline_id / pipeline_run_id kwargs
  # directly — confirm against the task_lib version in use.
  node_uid = task_lib.NodeUid(
      pipeline_id='pipeline', pipeline_run_id='run0', node_id='Trainer')
  exec_node_task = test_utils.create_exec_node_task(node_uid)
  self.assertEqual('ExecNodeTask', exec_node_task.task_type_id())
  cancel_node_task = task_lib.CancelNodeTask(node_uid=node_uid)
  self.assertEqual('CancelNodeTask', cancel_node_task.task_type_id())
def test_node_uid_from_pipeline_node(self):
  """Verifies from_pipeline_node picks up the pipeline run id."""
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = 'pipeline'
  pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
  node = pipeline_pb2.PipelineNode()
  node.node_info.id = 'Trainer'
  expected = task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(
          pipeline_id='pipeline', pipeline_run_id='run0'),
      node_id='Trainer')
  self.assertEqual(expected,
                   task_lib.NodeUid.from_pipeline_node(pipeline, node))
def test_register_using_node_type_name(self):
  """Tests scheduler lookup keyed by node type name (mutates the global registry)."""
  # Register a fake task scheduler.
  ts.TaskSchedulerRegistry.register(constants.IMPORTER_NODE_TYPE,
                                    _FakeTaskScheduler)
  # Create a task and verify that the correct scheduler is instantiated.
  task = test_utils.create_exec_node_task(
      node_uid=task_lib.NodeUid(
          pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
          node_id='Importer'),
      pipeline=self._pipeline)
  task_scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
      mock.Mock(), self._pipeline, task)
  self.assertIsInstance(task_scheduler, _FakeTaskScheduler)
def test_exec_node_task_create(self):
  """Tests ExecNodeTask.create builds a task equal to a hand-built one."""
  pipeline = pipeline_pb2.Pipeline()
  pipeline.pipeline_info.id = 'pipeline'
  pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
  node = pipeline_pb2.PipelineNode()
  node.node_info.id = 'Trainer'
  # NOTE(review): NodeUid built with pipeline_id / pipeline_run_id kwargs
  # directly — confirm against the task_lib version in use.
  self.assertEqual(
      task_lib.ExecNodeTask(
          node_uid=task_lib.NodeUid(
              pipeline_id='pipeline',
              pipeline_run_id='run0',
              node_id='Trainer'),
          execution_id=123),
      task_lib.ExecNodeTask.create(pipeline, node, 123))
def test_register_using_executor_spec_type_url(self):
  """Verifies scheduler lookup keyed by executor spec type url."""
  # Register a fake task scheduler under the executor spec type url.
  ts.TaskSchedulerRegistry.register(self._spec_type_url, _FakeTaskScheduler)
  # Build a task for a node using that executor spec and check the
  # registry instantiates the fake scheduler for it.
  node_uid = task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
      node_id='Trainer')
  task = test_utils.create_exec_node_task(
      node_uid=node_uid, pipeline=self._pipeline)
  scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
      mock.Mock(), self._pipeline, task)
  self.assertIsInstance(scheduler, _FakeTaskScheduler)
def test_initiate_node_start_stop(self):
  """Tests node state transitions persist across MLMD reloads."""
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1')
    node_uid = task_lib.NodeUid(
        node_id='Trainer',
        pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline))
    with pstate.PipelineState.new(m, pipeline) as pipeline_state:
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        node_state.update(pstate.NodeState.STARTING)
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STARTING, node_state.state)
    # Reload from MLMD and verify node is started.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STARTING, node_state.state)
      # Set node state to STOPPING.
      status = status_lib.Status(
          code=status_lib.Code.ABORTED, message='foo bar')
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        node_state.update(pstate.NodeState.STOPPING, status)
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STOPPING, node_state.state)
      self.assertEqual(status, node_state.status)
    # Reload from MLMD and verify node is stopped.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STOPPING, node_state.state)
      self.assertEqual(status, node_state.status)
      # Set node state to STARTED.
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        node_state.update(pstate.NodeState.STARTED)
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STARTED, node_state.state)
    # Reload from MLMD and verify node is started.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STARTED, node_state.state)
def test_handling_finalize_node_task(self, task_gen):
  """Tests that orchestrate handles a FinalizeNodeTask from the generator.

  Args:
    task_gen: Mocked task generator class (patched by the test decorator).
  """
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1')
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline_ops.initiate_pipeline_start(m, pipeline)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    finalize_reason = status_lib.Status(
        code=status_lib.Code.ABORTED, message='foo bar')
    # Generator yields one ExecNodeTask for Transform and one
    # FinalizeNodeTask for Trainer.
    task_gen.return_value.generate.side_effect = [
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=pipeline_uid, node_id='Transform')),
            task_lib.FinalizeNodeTask(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=pipeline_uid, node_id='Trainer'),
                status=finalize_reason)
        ],
    ]
    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())
    task_gen.return_value.generate.assert_called_once()
    # Only the ExecNodeTask is enqueued; the FinalizeNodeTask is handled by
    # recording stop-initiation on the node instead.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)
    # Load pipeline state and verify node stop initiation.
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      self.assertEqual(
          finalize_reason,
          pipeline_state.node_stop_initiated_reason(
              task_lib.NodeUid(pipeline_uid=pipeline_uid,
                               node_id='Trainer')))
def test_get_node_states_dict(self):
  """Tests get_node_states_dict returns states for all nodes incl. defaults."""
  with self._mlmd_connection as m:
    pipeline = pipeline_pb2.Pipeline()
    pipeline.pipeline_info.id = 'pipeline1'
    pipeline.execution_mode = pipeline_pb2.Pipeline.SYNC
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'
    eg_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
    transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform')
    trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')
    evaluator_node_uid = task_lib.NodeUid(pipeline_uid, 'Evaluator')
    # Update three nodes; 'Evaluator' is deliberately left untouched to
    # exercise the default state in the returned dict.
    with pstate.PipelineState.new(m, pipeline) as pipeline_state:
      with pipeline_state.node_state_update_context(
          eg_node_uid) as node_state:
        node_state.update(pstate.NodeState.COMPLETE)
      with pipeline_state.node_state_update_context(
          transform_node_uid) as node_state:
        node_state.update(pstate.NodeState.RUNNING)
      with pipeline_state.node_state_update_context(
          trainer_node_uid) as node_state:
        node_state.update(pstate.NodeState.STARTING)
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      self.assertEqual(
          {
              eg_node_uid:
                  pstate.NodeState(state=pstate.NodeState.COMPLETE),
              transform_node_uid:
                  pstate.NodeState(state=pstate.NodeState.RUNNING),
              trainer_node_uid:
                  pstate.NodeState(state=pstate.NodeState.STARTING),
              evaluator_node_uid:
                  pstate.NodeState(state=pstate.NodeState.STARTED),
          }, pipeline_state.get_node_states_dict())
def test_initiate_node_start_stop(self):
  """Tests node start/stop initiation persists across MLMD reloads (reason-based API)."""
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1')
    node_uid = task_lib.NodeUid(
        node_id='Trainer',
        pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline))
    with pstate.PipelineState.new(m, pipeline) as pipeline_state:
      pipeline_state.initiate_node_start(node_uid)
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
    # Reload from MLMD and verify node is started.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
      # Stop the node.
      status = status_lib.Status(
          code=status_lib.Code.ABORTED, message='foo bar')
      pipeline_state.initiate_node_stop(node_uid, status)
      self.assertEqual(status,
                       pipeline_state.node_stop_initiated_reason(node_uid))
    # Reload from MLMD and verify node is stopped.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      self.assertEqual(status,
                       pipeline_state.node_stop_initiated_reason(node_uid))
      # Restart node.
      pipeline_state.initiate_node_start(node_uid)
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
    # Reload from MLMD and verify node is started.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
def test_registration_and_creation(self):
  """Tests registering a scheduler by type url and creating it for a task."""
  # Create a pipeline IR containing deployment config for testing.
  deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
  executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
      class_path='trainer.TrainerExecutor')
  deployment_config.executor_specs['Trainer'].Pack(executor_spec)
  pipeline = pipeline_pb2.Pipeline()
  pipeline.deployment_config.Pack(deployment_config)
  # Register a fake task scheduler.
  spec_type_url = deployment_config.executor_specs['Trainer'].type_url
  ts.TaskSchedulerRegistry.register(spec_type_url, _FakeTaskScheduler)
  # Create a task and verify that the correct scheduler is instantiated.
  # NOTE(review): NodeUid built with pipeline_id / pipeline_run_id kwargs
  # directly — confirm against the task_lib version in use.
  task = test_utils.create_exec_node_task(
      node_uid=task_lib.NodeUid(
          pipeline_id='pipeline', pipeline_run_id=None, node_id='Trainer'))
  task_scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
      mock.Mock(), pipeline, task)
  self.assertIsInstance(task_scheduler, _FakeTaskScheduler)
def test_initiate_node_start_stop(self):
  """Tests node start/stop initiation persists across MLMD reloads (boolean API)."""
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1')
    node_uid = task_lib.NodeUid(
        node_id='Trainer',
        pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline))
    with pstate.PipelineState.new(m, pipeline) as pipeline_state:
      pipeline_state.initiate_node_start(node_uid)
      self.assertFalse(pipeline_state.is_node_stop_initiated(node_uid))
    # Reload from MLMD and verify node is started.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      self.assertFalse(pipeline_state.is_node_stop_initiated(node_uid))
      # Stop the node.
      pipeline_state.initiate_node_stop(node_uid)
      self.assertTrue(pipeline_state.is_node_stop_initiated(node_uid))
    # Reload from MLMD and verify node is stopped.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      self.assertTrue(pipeline_state.is_node_stop_initiated(node_uid))
      # Restart node.
      pipeline_state.initiate_node_start(node_uid)
      self.assertFalse(pipeline_state.is_node_stop_initiated(node_uid))
    # Reload from MLMD and verify node is started.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      self.assertFalse(pipeline_state.is_node_stop_initiated(node_uid))
def create_node_uid(pipeline_id, node_id):
  """Creates node uid."""
  pipeline_uid = task_lib.PipelineUid(pipeline_id=pipeline_id)
  return task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id=node_id)
def _test_cancel_task(node_id, pipeline_id, pipeline_run_id=None):
  """Builds a CancelNodeTask for the given node (test helper)."""
  uid = task_lib.NodeUid(
      pipeline_id=pipeline_id,
      pipeline_run_id=pipeline_run_id,
      node_id=node_id)
  return task_lib.CancelNodeTask(node_uid=uid)
def _test_cancel_node_task(node_id, pipeline_id):
  """Builds a CancelNodeTask for the given node (test helper)."""
  pipeline_uid = task_lib.PipelineUid(pipeline_id=pipeline_id)
  uid = task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id=node_id)
  return task_lib.CancelNodeTask(node_uid=uid)
def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                        mock_async_task_gen,
                                        mock_sync_task_gen):
  """Tests task generation for a stop-initiated async pipeline.

  Args:
    mock_gen_task_from_active: Mock for generating cancellation tasks from
      active executions (patched by the test decorator).
    mock_async_task_gen: Mocked async task generator class.
    mock_sync_task_gen: Mocked sync task generator class.
  """
  with self._mlmd_connection as m:
    pipeline1 = _test_pipeline('pipeline1')
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Evaluator'
    pipeline_ops.initiate_pipeline_start(m, pipeline1)
    pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
        m, task_lib.PipelineUid.from_pipeline(pipeline1))

    task_queue = tq.TaskQueue()

    # For the stop-initiated pipeline, "Transform" execution task is in queue,
    # "Trainer" has an active execution in MLMD but no task in queue,
    # "Evaluator" has no active execution.
    task_queue.enqueue(
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid(
                    pipeline_id='pipeline1', pipeline_run_id=None),
                node_id='Transform')))
    transform_task = task_queue.dequeue()  # simulates task being processed
    mock_gen_task_from_active.side_effect = [
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid(
                    pipeline_id='pipeline1', pipeline_run_id=None),
                node_id='Trainer'),
            is_cancelled=True), None, None, None, None
    ]

    pipeline_ops.generate_tasks(m, task_queue)

    # There are no active pipelines so these shouldn't be called.
    mock_async_task_gen.assert_not_called()
    mock_sync_task_gen.assert_not_called()

    # Simulate finishing the "Transform" ExecNodeTask.
    task_queue.task_done(transform_task)

    # CancelNodeTask for the "Transform" ExecNodeTask should be next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_cancel_node_task(task))
    self.assertEqual('Transform', task.node_uid.node_id)

    # ExecNodeTask for "Trainer" is next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual('Trainer', task.node_uid.node_id)
    self.assertTrue(task_queue.is_empty())

    # Cancellation generation should have been attempted for the two nodes
    # that were not already represented by an ExecNodeTask in the queue.
    mock_gen_task_from_active.assert_has_calls([
        mock.call(
            m,
            pipeline1,
            pipeline1.nodes[1].pipeline_node,
            mock.ANY,
            is_cancelled=True),
        mock.call(
            m,
            pipeline1,
            pipeline1.nodes[2].pipeline_node,
            mock.ANY,
            is_cancelled=True)
    ])
    self.assertEqual(2, mock_gen_task_from_active.call_count)

    # Pipeline execution should continue to be active since active node
    # executions were found in the last call to `generate_tasks`.
    [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
    self.assertTrue(execution_lib.is_execution_active(execution))

    # Call `generate_tasks` again; this time there are no more active node
    # executions so the pipeline should be marked as cancelled.
    pipeline_ops.generate_tasks(m, task_queue)
    self.assertTrue(task_queue.is_empty())
    [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
    self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                     execution.last_known_state)
def _update_successful_nodes_cache(self, node_ids: Set[str]) -> None:
  """Marks each of the given node ids as successful in the module-level cache."""
  _successful_nodes_cache.update({
      self._node_cache_key(
          task_lib.NodeUid(pipeline_uid=self._pipeline_uid, node_id=nid)):
          True for nid in node_ids
  })
def _test_task(node_id, pipeline_id, pipeline_run_id=None):
  """Builds an ExecNodeTask with a fixed execution id (test helper)."""
  uid = task_lib.NodeUid(
      pipeline_id=pipeline_id,
      pipeline_run_id=pipeline_run_id,
      node_id=node_id)
  return task_lib.ExecNodeTask(node_uid=uid, execution_id=123)
def test_orchestrate_active_pipelines(self, mock_async_task_gen,
                                      mock_sync_task_gen):
  """Tests that orchestrate generates tasks for all active pipelines.

  Args:
    mock_async_task_gen: Mocked async task generator class.
    mock_sync_task_gen: Mocked sync task generator class.
  """
  with self._mlmd_connection as m:
    # Sync and async active pipelines.
    async_pipelines = [
        _test_pipeline('pipeline1'),
        _test_pipeline('pipeline2'),
    ]
    sync_pipelines = [
        _test_pipeline('pipeline3', pipeline_pb2.Pipeline.SYNC),
        _test_pipeline('pipeline4', pipeline_pb2.Pipeline.SYNC),
    ]
    for pipeline in async_pipelines + sync_pipelines:
      pipeline_ops.initiate_pipeline_start(m, pipeline)

    # Active executions for active async pipelines.
    mock_async_task_gen.return_value.generate.side_effect = [
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        async_pipelines[0]),
                    node_id='Transform'))
        ],
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        async_pipelines[1]),
                    node_id='Trainer'))
        ],
    ]

    # Active executions for active sync pipelines.
    mock_sync_task_gen.return_value.generate.side_effect = [
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        sync_pipelines[0]),
                    node_id='Trainer'))
        ],
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid.from_pipeline(
                        sync_pipelines[1]),
                    node_id='Validator'))
        ],
    ]

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())

    # One generate call per active pipeline of each mode.
    self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
    self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count)

    # Verify that tasks are enqueued in the expected order.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline2', 'Trainer'), task.node_uid)
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline3', 'Trainer'), task.node_uid)
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline4', 'Validator'), task.node_uid)
    self.assertTrue(task_queue.is_empty())
def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
                                  mock_async_task_gen, mock_sync_task_gen):
  """Tests orchestration of a stop-initiated pipeline with service nodes.

  Args:
    pipeline: Parameterized pipeline IR under test.
    mock_gen_task_from_active: Mock for generating cancellation tasks from
      active executions (patched by the test decorator).
    mock_async_task_gen: Mocked async task generator class.
    mock_sync_task_gen: Mocked sync task generator class.
  """
  with self._mlmd_connection as m:
    pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

    # 'ExampleGen' is a pure service node; 'Transform' is mixed.
    mock_service_job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)
    mock_service_job_manager.is_pure_service_node.side_effect = (
        lambda _, node_id: node_id == 'ExampleGen')
    mock_service_job_manager.is_mixed_service_node.side_effect = (
        lambda _, node_id: node_id == 'Transform')

    pipeline_ops.initiate_pipeline_start(m, pipeline)
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      pipeline_state.initiate_stop(
          status_lib.Status(code=status_lib.Code.CANCELLED))
      pipeline_execution = pipeline_state.execution

    task_queue = tq.TaskQueue()

    # For the stop-initiated pipeline, "Transform" execution task is in queue,
    # "Trainer" has an active execution in MLMD but no task in queue,
    # "Evaluator" has no active execution.
    task_queue.enqueue(
        test_utils.create_exec_node_task(
            task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                node_id='Transform')))
    transform_task = task_queue.dequeue()  # simulates task being processed
    mock_gen_task_from_active.side_effect = [
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                node_id='Trainer'),
            is_cancelled=True), None, None, None, None
    ]

    pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

    # There are no active pipelines so these shouldn't be called.
    mock_async_task_gen.assert_not_called()
    mock_sync_task_gen.assert_not_called()

    # stop_node_services should be called for ExampleGen which is a pure
    # service node.
    mock_service_job_manager.stop_node_services.assert_called_once_with(
        mock.ANY, 'ExampleGen')
    mock_service_job_manager.reset_mock()

    task_queue.task_done(transform_task)  # Pop out transform task.

    # CancelNodeTask for the "Transform" ExecNodeTask should be next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_cancel_node_task(task))
    self.assertEqual('Transform', task.node_uid.node_id)

    # ExecNodeTask (with is_cancelled=True) for "Trainer" is next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual('Trainer', task.node_uid.node_id)
    self.assertTrue(task.is_cancelled)

    self.assertTrue(task_queue.is_empty())

    mock_gen_task_from_active.assert_has_calls([
        mock.call(
            m,
            pipeline_state.pipeline,
            pipeline.nodes[2].pipeline_node,
            mock.ANY,
            is_cancelled=True),
        mock.call(
            m,
            pipeline_state.pipeline,
            pipeline.nodes[3].pipeline_node,
            mock.ANY,
            is_cancelled=True)
    ])
    self.assertEqual(2, mock_gen_task_from_active.call_count)

    # Pipeline execution should continue to be active since active node
    # executions were found in the last call to `orchestrate`.
    [execution] = m.store.get_executions_by_id([pipeline_execution.id])
    self.assertTrue(execution_lib.is_execution_active(execution))

    # Call `orchestrate` again; this time there are no more active node
    # executions so the pipeline should be marked as cancelled.
    pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
    self.assertTrue(task_queue.is_empty())
    [execution] = m.store.get_executions_by_id([pipeline_execution.id])
    self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                     execution.last_known_state)

    # stop_node_services should be called on both ExampleGen and Transform
    # which are service nodes.
    mock_service_job_manager.stop_node_services.assert_has_calls(
        [mock.call(mock.ANY, 'ExampleGen'),
         mock.call(mock.ANY, 'Transform')],
        any_order=True)
def _test_task(node_id, pipeline_id, key=''):
  """Builds an ExecNodeTask for the given node (test helper)."""
  pipeline_uid = task_lib.PipelineUid(pipeline_id=pipeline_id, key=key)
  uid = task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id=node_id)
  return test_utils.create_exec_node_task(uid)