def _handle_task(self, task: task_lib.Task) -> None:
    """Routes a task to the handler matching its concrete task type.

    Args:
        task: The task to dispatch.

    Raises:
        RuntimeError: If the task is neither an exec-node task nor a
            cancel-node task.
    """
    if task_lib.is_exec_node_task(task):
        exec_task = typing.cast(task_lib.ExecNodeTask, task)
        self._handle_exec_node_task(exec_task)
        return

    if task_lib.is_cancel_node_task(task):
        cancel_task = typing.cast(task_lib.CancelNodeTask, task)
        self._handle_cancel_node_task(cancel_task)
        return

    # Anything else is a task type this dispatcher does not know about.
    raise RuntimeError('Cannot dispatch bad task: {}'.format(task))
def test_active_pipelines_with_stop_initiated_nodes(self,
                                                    mock_gen_task_from_active,
                                                    mock_async_task_gen):
    """Checks task ordering for an active pipeline with stop-initiated nodes.

    ExampleGen, Trainer and Evaluator are stop-initiated while Transform is
    left alone. The test expects `stop_node_services` to be called for
    ExampleGen (the mocked pure service node), a `CancelNodeTask` for Trainer
    (whose `ExecNodeTask` is already queued), the Evaluator task supplied by
    `mock_gen_task_from_active`, and finally the newly generated Transform
    task.

    Args:
        mock_gen_task_from_active: Mock whose `side_effect` supplies the task
            produced for Evaluator's active execution.
        mock_async_task_gen: Mock whose `return_value.generate` supplies the
            newly triggered tasks.
    """
    with self._mlmd_connection as m:
        pipeline = _test_pipeline('pipeline')
        pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

        mock_service_job_manager = mock.create_autospec(
            service_jobs.ServiceJobManager, instance=True)
        mock_service_job_manager.is_pure_service_node.side_effect = (
            lambda _, node_id: node_id == 'ExampleGen')

        example_gen_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[0].pipeline_node)

        transform_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[1].pipeline_node)
        transform_task = test_utils.create_exec_node_task(
            node_uid=transform_node_uid)

        trainer_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[2].pipeline_node)
        trainer_task = test_utils.create_exec_node_task(
            node_uid=trainer_node_uid)

        evaluator_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[3].pipeline_node)
        evaluator_task = test_utils.create_exec_node_task(
            node_uid=evaluator_node_uid)

        pipeline_ops.initiate_pipeline_start(m, pipeline)
        with pstate.PipelineState.load(
                m, task_lib.PipelineUid.from_pipeline(
                    pipeline)) as pipeline_state:
            # Stop example-gen, trainer and evaluator.
            pipeline_state.initiate_node_stop(
                example_gen_node_uid,
                status_lib.Status(code=status_lib.Code.CANCELLED))
            pipeline_state.initiate_node_stop(
                trainer_node_uid,
                status_lib.Status(code=status_lib.Code.CANCELLED))
            pipeline_state.initiate_node_stop(
                evaluator_node_uid,
                status_lib.Status(code=status_lib.Code.ABORTED))

        task_queue = tq.TaskQueue()

        # Simulate a new transform execution being triggered.
        mock_async_task_gen.return_value.generate.return_value = [
            transform_task
        ]
        # Simulate ExecNodeTask for trainer already present in the task queue.
        task_queue.enqueue(trainer_task)
        # Simulate Evaluator having an active execution in MLMD.
        mock_gen_task_from_active.side_effect = [evaluator_task]

        pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
        self.assertEqual(
            1, mock_async_task_gen.return_value.generate.call_count)

        # stop_node_services should be called on example-gen which is a pure
        # service node.
        mock_service_job_manager.stop_node_services.assert_called_once_with(
            mock.ANY, 'ExampleGen')

        # Verify that tasks are enqueued in the expected order:

        # Pre-existing trainer task.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertEqual(trainer_task, task)

        # CancelNodeTask for trainer.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_cancel_node_task(task))
        self.assertEqual(trainer_node_uid, task.node_uid)

        # ExecNodeTask for evaluator. The previous
        # `assertTrue(expected, actual)` only checked that `expected` was
        # truthy (the second argument is the failure message), so it never
        # looked at the dequeued task. The task content comes from the mocked
        # generator, so verify its type and target node instead.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_exec_node_task(task))
        self.assertEqual(evaluator_node_uid, task.node_uid)

        # ExecNodeTask for newly triggered transform node.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertEqual(transform_task, task)

        # No more tasks.
        self.assertTrue(task_queue.is_empty())
def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
                                  mock_async_task_gen, mock_sync_task_gen):
    """Checks how `orchestrate` drives a stop-initiated pipeline to CANCELED.

    The first `orchestrate` call is expected to stop services for the pure
    service node, enqueue a `CancelNodeTask` for the in-flight Transform task
    and enqueue the cancelled Trainer task supplied by the mock, leaving the
    pipeline execution active. The second call, with no remaining active node
    executions, is expected to mark the pipeline execution as CANCELED.

    Args:
        pipeline: The pipeline under test; nodes are added to it here.
        mock_gen_task_from_active: Mock whose `side_effect` supplies the tasks
            produced for nodes with active executions.
        mock_async_task_gen: Mock async task generator; must not be called.
        mock_sync_task_gen: Mock sync task generator; must not be called.
    """
    with self._mlmd_connection as m:
        for name in ('ExampleGen', 'Transform', 'Trainer', 'Evaluator'):
            pipeline.nodes.add().pipeline_node.node_info.id = name

        mock_service_job_manager = mock.create_autospec(
            service_jobs.ServiceJobManager, instance=True)
        mock_service_job_manager.is_pure_service_node.side_effect = (
            lambda _, node_id: node_id == 'ExampleGen')
        mock_service_job_manager.is_mixed_service_node.side_effect = (
            lambda _, node_id: node_id == 'Transform')

        def _node_uid(node_id):
            # Builds the uid of the named node within the pipeline under test.
            return task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                node_id=node_id)

        pipeline_ops.initiate_pipeline_start(m, pipeline)
        with pstate.PipelineState.load(
                m, task_lib.PipelineUid.from_pipeline(
                    pipeline)) as pipeline_state:
            pipeline_state.initiate_stop(
                status_lib.Status(code=status_lib.Code.CANCELLED))
            pipeline_execution = pipeline_state.execution

        task_queue = tq.TaskQueue()

        # For the stop-initiated pipeline, "Transform" execution task is in
        # queue, "Trainer" has an active execution in MLMD but no task in
        # queue, "Evaluator" has no active execution.
        task_queue.enqueue(
            test_utils.create_exec_node_task(_node_uid('Transform')))
        # Dequeuing without `task_done` simulates the task being processed.
        in_flight_transform = task_queue.dequeue()
        cancelled_trainer = test_utils.create_exec_node_task(
            node_uid=_node_uid('Trainer'), is_cancelled=True)
        mock_gen_task_from_active.side_effect = [
            cancelled_trainer, None, None, None, None
        ]

        pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

        # There are no active pipelines so these shouldn't be called.
        mock_async_task_gen.assert_not_called()
        mock_sync_task_gen.assert_not_called()

        # stop_node_services should be called for ExampleGen which is a pure
        # service node.
        mock_service_job_manager.stop_node_services.assert_called_once_with(
            mock.ANY, 'ExampleGen')
        mock_service_job_manager.reset_mock()

        # Pop out transform task.
        task_queue.task_done(in_flight_transform)

        # CancelNodeTask for the "Transform" ExecNodeTask should be next.
        next_task = task_queue.dequeue()
        task_queue.task_done(next_task)
        self.assertTrue(task_lib.is_cancel_node_task(next_task))
        self.assertEqual('Transform', next_task.node_uid.node_id)

        # ExecNodeTask (with is_cancelled=True) for "Trainer" is next.
        next_task = task_queue.dequeue()
        task_queue.task_done(next_task)
        self.assertTrue(task_lib.is_exec_node_task(next_task))
        self.assertEqual('Trainer', next_task.node_uid.node_id)
        self.assertTrue(next_task.is_cancelled)

        self.assertTrue(task_queue.is_empty())

        expected_calls = [
            mock.call(
                m,
                pipeline_state.pipeline,
                pipeline.nodes[node_index].pipeline_node,
                mock.ANY,
                is_cancelled=True) for node_index in (2, 3)
        ]
        mock_gen_task_from_active.assert_has_calls(expected_calls)
        self.assertEqual(2, mock_gen_task_from_active.call_count)

        # Pipeline execution should continue to be active since active node
        # executions were found in the last call to `orchestrate`.
        (execution,) = m.store.get_executions_by_id([pipeline_execution.id])
        self.assertTrue(execution_lib.is_execution_active(execution))

        # Call `orchestrate` again; this time there are no more active node
        # executions so the pipeline should be marked as cancelled.
        pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
        self.assertTrue(task_queue.is_empty())
        (execution,) = m.store.get_executions_by_id([pipeline_execution.id])
        self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                         execution.last_known_state)

        # stop_node_services should be called on both ExampleGen and Transform
        # which are service nodes.
        mock_service_job_manager.stop_node_services.assert_has_calls(
            [
                mock.call(mock.ANY, 'ExampleGen'),
                mock.call(mock.ANY, 'Transform'),
            ],
            any_order=True)
def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                        mock_async_task_gen,
                                        mock_sync_task_gen):
    """Checks how `generate_tasks` drives a stop-initiated pipeline.

    The first `generate_tasks` call is expected to enqueue a `CancelNodeTask`
    for the in-flight Transform task and the Trainer task supplied by the
    mock, leaving the pipeline execution active. The second call, with no
    remaining active node executions, is expected to mark the pipeline
    execution as CANCELED.

    Args:
        mock_gen_task_from_active: Mock whose `side_effect` supplies the tasks
            produced for nodes with active executions.
        mock_async_task_gen: Mock async task generator; must not be called.
        mock_sync_task_gen: Mock sync task generator; must not be called.
    """
    with self._mlmd_connection as m:
        pipeline1 = _test_pipeline('pipeline1')
        for name in ('Transform', 'Trainer', 'Evaluator'):
            pipeline1.nodes.add().pipeline_node.node_info.id = name
        pipeline_ops.initiate_pipeline_start(m, pipeline1)
        pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
            m, task_lib.PipelineUid.from_pipeline(pipeline1))

        def _pipeline1_node_uid(node_id):
            # Builds the uid of the named node within 'pipeline1'.
            return task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid(
                    pipeline_id='pipeline1', pipeline_run_id=None),
                node_id=node_id)

        task_queue = tq.TaskQueue()

        # For the stop-initiated pipeline, "Transform" execution task is in
        # queue, "Trainer" has an active execution in MLMD but no task in
        # queue, "Evaluator" has no active execution.
        task_queue.enqueue(
            test_utils.create_exec_node_task(
                node_uid=_pipeline1_node_uid('Transform')))
        # Dequeuing without `task_done` simulates the task being processed.
        in_flight_transform = task_queue.dequeue()
        cancelled_trainer = test_utils.create_exec_node_task(
            node_uid=_pipeline1_node_uid('Trainer'), is_cancelled=True)
        mock_gen_task_from_active.side_effect = [
            cancelled_trainer, None, None, None, None
        ]

        pipeline_ops.generate_tasks(m, task_queue)

        # There are no active pipelines so these shouldn't be called.
        mock_async_task_gen.assert_not_called()
        mock_sync_task_gen.assert_not_called()

        # Simulate finishing the "Transform" ExecNodeTask.
        task_queue.task_done(in_flight_transform)

        # CancelNodeTask for the "Transform" ExecNodeTask should be next.
        next_task = task_queue.dequeue()
        task_queue.task_done(next_task)
        self.assertTrue(task_lib.is_cancel_node_task(next_task))
        self.assertEqual('Transform', next_task.node_uid.node_id)

        # ExecNodeTask for "Trainer" is next.
        next_task = task_queue.dequeue()
        task_queue.task_done(next_task)
        self.assertTrue(task_lib.is_exec_node_task(next_task))
        self.assertEqual('Trainer', next_task.node_uid.node_id)

        self.assertTrue(task_queue.is_empty())

        expected_calls = [
            mock.call(
                m,
                pipeline1,
                pipeline1.nodes[node_index].pipeline_node,
                mock.ANY,
                is_cancelled=True) for node_index in (1, 2)
        ]
        mock_gen_task_from_active.assert_has_calls(expected_calls)
        self.assertEqual(2, mock_gen_task_from_active.call_count)

        # Pipeline execution should continue to be active since active node
        # executions were found in the last call to `generate_tasks`.
        (execution,) = m.store.get_executions_by_id([pipeline1_execution.id])
        self.assertTrue(execution_lib.is_execution_active(execution))

        # Call `generate_tasks` again; this time there are no more active node
        # executions so the pipeline should be marked as cancelled.
        pipeline_ops.generate_tasks(m, task_queue)
        self.assertTrue(task_queue.is_empty())
        (execution,) = m.store.get_executions_by_id([pipeline1_execution.id])
        self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                         execution.last_known_state)
def test_active_pipelines_with_stop_initiated_nodes(
        self, mock_gen_task_from_active, mock_async_task_gen):
    """Checks task ordering for an active pipeline with stop-initiated nodes.

    Trainer and Evaluator are stop-initiated while Transform is left alone.
    The test expects a `CancelNodeTask` for Trainer (whose `ExecNodeTask` is
    already queued), the Evaluator task supplied by
    `mock_gen_task_from_active`, and finally the newly generated Transform
    task.

    Args:
        mock_gen_task_from_active: Mock whose `side_effect` supplies the task
            produced for Evaluator's active execution.
        mock_async_task_gen: Mock whose `return_value.generate` supplies the
            newly triggered tasks.
    """
    with self._mlmd_connection as m:
        pipeline = _test_pipeline('pipeline')
        pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
        pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

        transform_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[0].pipeline_node)
        transform_task = test_utils.create_exec_node_task(
            node_uid=transform_node_uid)

        trainer_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[1].pipeline_node)
        trainer_task = test_utils.create_exec_node_task(
            node_uid=trainer_node_uid)

        evaluator_node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, pipeline.nodes[2].pipeline_node)
        evaluator_task = test_utils.create_exec_node_task(
            node_uid=evaluator_node_uid)

        pipeline_ops.initiate_pipeline_start(m, pipeline)
        with pstate.PipelineState.load(
                m, task_lib.PipelineUid.from_pipeline(
                    pipeline)) as pipeline_state:
            # Stop trainer and evaluator.
            pipeline_state.initiate_node_stop(trainer_node_uid)
            pipeline_state.initiate_node_stop(evaluator_node_uid)

        task_queue = tq.TaskQueue()

        # Simulate a new transform execution being triggered.
        mock_async_task_gen.return_value.generate.return_value = [
            transform_task
        ]
        # Simulate ExecNodeTask for trainer already present in the task queue.
        task_queue.enqueue(trainer_task)
        # Simulate Evaluator having an active execution in MLMD.
        mock_gen_task_from_active.side_effect = [evaluator_task]

        pipeline_ops.orchestrate(m, task_queue)
        self.assertEqual(
            1, mock_async_task_gen.return_value.generate.call_count)

        # Verify that tasks are enqueued in the expected order:

        # Pre-existing trainer task.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertEqual(trainer_task, task)

        # CancelNodeTask for trainer.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_cancel_node_task(task))
        self.assertEqual(trainer_node_uid, task.node_uid)

        # ExecNodeTask for evaluator. The previous
        # `assertTrue(expected, actual)` only checked that `expected` was
        # truthy (the second argument is the failure message), so it never
        # looked at the dequeued task. The task content comes from the mocked
        # generator, so verify its type and target node instead.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertTrue(task_lib.is_exec_node_task(task))
        self.assertEqual(evaluator_node_uid, task.node_uid)

        # ExecNodeTask for newly triggered transform node.
        task = task_queue.dequeue()
        task_queue.task_done(task)
        self.assertEqual(transform_task, task)

        # No more tasks.
        self.assertTrue(task_queue.is_empty())