コード例 #1
0
ファイル: task_manager_test.py プロジェクト: suryaavala/tfx
 def _task_manager(self, task_queue):
     """Yields an active TaskManager that consumes `task_queue`.

     The manager is entered before the `yield` and exited once control comes
     back to this generator. NOTE(review): presumably wrapped by
     `contextlib.contextmanager` outside this view — confirm.

     Args:
       task_queue: Queue the TaskManager dequeues tasks from.

     Yields:
       The object returned by entering `tm.TaskManager(...)`.
     """
     # A plain Mock stands in for the first positional argument (the MLMD
     # handle), so no real metadata store is required.
     manager_cm = tm.TaskManager(
         mock.Mock(),
         task_queue,
         max_active_task_schedulers=1000,
         max_dequeue_wait_secs=0.1,
         process_all_queued_tasks_before_exit=True)
     with manager_cm as active_manager:
         yield active_manager
コード例 #2
0
ファイル: task_manager_test.py プロジェクト: suryaavala/tfx
 def _run_task_manager(self):
     """Enters and exits a TaskManager over `self._task_queue`, then returns it.

     The manager is created inside the open `self._mlmd_connection` context.
     No work is done while it is active. NOTE(review): with
     `process_all_queued_tasks_before_exit=True`, exiting is presumably where
     already-queued tasks get processed — confirm in `tm.TaskManager`.

     Returns:
       The object bound by entering `tm.TaskManager(...)`, after its context
       (and the MLMD connection's) has closed.
     """
     # Multi-item `with` is equivalent to nesting: the connection is entered
     # first and its bound handle feeds the TaskManager constructor.
     # The positional 1000 is presumably max_active_task_schedulers — confirm.
     with self._mlmd_connection as mlmd_handle, tm.TaskManager(
             mlmd_handle,
             self._task_queue,
             1000,
             max_dequeue_wait_secs=0.1,
             process_all_queued_tasks_before_exit=True) as finished_manager:
         pass  # Entering/exiting the manager is the whole point here.
     return finished_manager
コード例 #3
0
    def test_queue_multiplexing(self, mock_publish):
        """Exec and cancel tasks for several nodes flow through one task queue.

        Some tasks are enqueued before the TaskManager starts and some while it
        is running. The test then checks which exec tasks were scheduled, which
        were cancelled, and that a publish call was made for each exec task.

        Args:
          mock_publish: Mock whose calls are asserted at the end. NOTE(review):
            presumably injected by a `mock.patch` decorator outside this view.
        """
        # Build a pipeline IR whose deployment config maps every node to the
        # same Python-class executor spec.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='trainer.TrainerExecutor')
        for node_id in ('Trainer', 'Transform', 'Evaluator'):
            deployment_config.executor_specs[node_id].Pack(executor_spec)
        pipeline = pipeline_pb2.Pipeline()
        pipeline.deployment_config.Pack(deployment_config)

        collector = _Collector()

        # Register a single fake task scheduler. All three nodes packed the
        # same executor spec message type, so they share one `type_url`, and
        # this one registration serves every node. NOTE(review): `block_nodes`
        # presumably keeps Trainer/Transform running until they are cancelled —
        # confirm in _FakeTaskScheduler. The registry is class-level state;
        # confirm it is reset between tests.
        ts.TaskSchedulerRegistry.register(
            deployment_config.executor_specs['Trainer'].type_url,
            functools.partial(_FakeTaskScheduler,
                              block_nodes={'Trainer', 'Transform'},
                              collector=collector))

        task_queue = tq.TaskQueue()

        # Enqueue some tasks before the task manager starts.
        trainer_exec_task = _test_exec_task('Trainer', 'test-pipeline')
        task_queue.enqueue(trainer_exec_task)
        task_queue.enqueue(_test_cancel_task('Trainer', 'test-pipeline'))

        with tm.TaskManager(mock.Mock(),
                            pipeline,
                            task_queue,
                            max_active_task_schedulers=1000,
                            max_dequeue_wait_secs=0.1,
                            process_all_queued_tasks_before_exit=True):
            # Enqueue more tasks after the task manager has started.
            transform_exec_task = _test_exec_task('Transform', 'test-pipeline')
            task_queue.enqueue(transform_exec_task)
            evaluator_exec_task = _test_exec_task('Evaluator', 'test-pipeline')
            task_queue.enqueue(evaluator_exec_task)
            task_queue.enqueue(_test_cancel_task('Transform', 'test-pipeline'))

        # Every exec task was scheduled. Only the nodes that also received a
        # cancel task (Trainer, Transform) were cancelled.
        self.assertCountEqual(
            [trainer_exec_task, transform_exec_task, evaluator_exec_task],
            collector.scheduled_tasks)
        self.assertCountEqual([trainer_exec_task, transform_exec_task],
                              collector.cancelled_tasks)
        # One publish call per exec task, in any order.
        mock_publish.assert_has_calls(
            [
                mock.call(mock.ANY, pipeline, trainer_exec_task, mock.ANY),
                mock.call(mock.ANY, pipeline, transform_exec_task, mock.ANY),
                mock.call(mock.ANY, pipeline, evaluator_exec_task, mock.ANY),
            ],
            any_order=True)