def _task_manager(self, task_queue):
    """Yields an active TaskManager wired to `task_queue`.

    A fresh `mock.Mock()` is passed as the manager's first positional
    argument (presumably the MLMD handle -- confirm against TaskManager), so
    no real metadata store is required. The manager stays entered while the
    caller holds the yielded value and is exited when this generator is
    resumed or closed. NOTE(review): presumably wrapped with
    `contextlib.contextmanager` where defined; the decorator is not visible
    here.

    Args:
      task_queue: Task queue handed to the TaskManager.

    Yields:
      The entered TaskManager instance.
    """
    settings = dict(
        max_active_task_schedulers=1000,
        max_dequeue_wait_secs=0.1,
        # Per its name: process all queued tasks before the manager exits.
        process_all_queued_tasks_before_exit=True,
    )
    fake_handle = mock.Mock()
    with tm.TaskManager(fake_handle, task_queue, **settings) as manager:
        yield manager
def _run_task_manager(self):
    """Starts and immediately stops a TaskManager over `self._task_queue`.

    Opens `self._mlmd_connection` and constructs a TaskManager with the
    resulting handle. The body of the manager's `with` block is empty.
    Because `process_all_queued_tasks_before_exit=True`, the manager
    presumably processes every task already in the queue while exiting.

    Returns:
      The TaskManager instance, already exited, so that callers can inspect
      it afterwards.
    """
    with self._mlmd_connection as m:
        # Pass the scheduler limit by keyword (it was a bare positional `1000`)
        # so it reads the same as the other TaskManager constructions.
        with tm.TaskManager(
                m,
                self._task_queue,
                max_active_task_schedulers=1000,
                max_dequeue_wait_secs=0.1,
                process_all_queued_tasks_before_exit=True) as task_manager:
            # Nothing to do while running: the work is whatever was queued
            # before this call, and it is handled as the manager exits.
            pass
    return task_manager
def test_queue_multiplexing(self, mock_publish):
    """Checks that one TaskManager handles exec and cancel tasks for many nodes.

    Exec tasks for three nodes and cancel tasks for two of them are enqueued.
    Some are enqueued before the manager starts and some while it runs. After
    the manager exits:
      * every exec task must have been scheduled;
      * exactly the Trainer and Transform exec tasks must have been cancelled;
      * `mock_publish` must have been called for each exec task with the
        pipeline.

    Args:
      mock_publish: Mock injected by a patch decorator that is not visible
        here; presumably it replaces the function that publishes execution
        results.
    """
    # Build a pipeline IR whose deployment config gives all three nodes the
    # same Python-class executor spec.
    deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
    executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
        class_path='trainer.TrainerExecutor')
    for node_id in ('Trainer', 'Transform', 'Evaluator'):
        deployment_config.executor_specs[node_id].Pack(executor_spec)
    pipeline = pipeline_pb2.Pipeline()
    pipeline.deployment_config.Pack(deployment_config)

    collector = _Collector()

    # Register a single fake task scheduler. All three nodes packed the same
    # executor spec, so they share one type URL and this one registration is
    # presumably used for every node. Trainer and Transform are listed in
    # `block_nodes` -- presumably their schedulers keep running until they
    # are cancelled.
    ts.TaskSchedulerRegistry.register(
        deployment_config.executor_specs['Trainer'].type_url,
        functools.partial(
            _FakeTaskScheduler,
            block_nodes={'Trainer', 'Transform'},
            collector=collector))

    task_queue = tq.TaskQueue()

    # Enqueue some tasks before the task manager starts.
    trainer_exec_task = _test_exec_task('Trainer', 'test-pipeline')
    task_queue.enqueue(trainer_exec_task)
    task_queue.enqueue(_test_cancel_task('Trainer', 'test-pipeline'))

    with tm.TaskManager(
            mock.Mock(),
            pipeline,
            task_queue,
            max_active_task_schedulers=1000,
            max_dequeue_wait_secs=0.1,
            process_all_queued_tasks_before_exit=True):
        # Enqueue more tasks after the task manager starts.
        transform_exec_task = _test_exec_task('Transform', 'test-pipeline')
        task_queue.enqueue(transform_exec_task)
        evaluator_exec_task = _test_exec_task('Evaluator', 'test-pipeline')
        task_queue.enqueue(evaluator_exec_task)
        task_queue.enqueue(_test_cancel_task('Transform', 'test-pipeline'))

    # With `process_all_queued_tasks_before_exit=True` the queue is presumably
    # drained by the time the `with` block exits, so check the results only
    # after exit.
    self.assertCountEqual(
        [trainer_exec_task, transform_exec_task, evaluator_exec_task],
        collector.scheduled_tasks)
    self.assertCountEqual([trainer_exec_task, transform_exec_task],
                          collector.cancelled_tasks)
    mock_publish.assert_has_calls([
        mock.call(mock.ANY, pipeline, trainer_exec_task, mock.ANY),
        mock.call(mock.ANY, pipeline, transform_exec_task, mock.ANY),
        mock.call(mock.ANY, pipeline, evaluator_exec_task, mock.ANY),
    ], any_order=True)