def test_action_repeated(self):
  """Checks which inner `step` calls ActionRepeat issues for two actions.

  Per the asserted call list, the first action `[2]` reaches the wrapped env
  once and the second action `[3]` reaches it `repeat_count` (3) times.
  """
  repeat_count = 3
  inner_env = self._get_mock_env_episode()
  repeated_env = wrappers.ActionRepeat(inner_env, repeat_count)
  repeated_env.reset()

  for action in ([2], [3]):
    repeated_env.step(action)

  expected_calls = [mock.call([2])]
  expected_calls.extend([mock.call([3])] * repeat_count)
  inner_env.step.assert_has_calls(expected_calls)
def testBlockingSave(self):
  """A blocking save returns only after the queued saves have run, in order.

  After `save(path2, blocking=True)` returns, the wrapped saver must already
  have received both the earlier non-blocking save and the blocking one.
  """
  saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
  async_saver = async_policy_saver.AsyncPolicySaver(saver)
  try:
    path1 = os.path.join(self.get_temp_dir(), 'save_model')
    path2 = os.path.join(self.get_temp_dir(), 'save_model2')

    self.evaluate(tf.compat.v1.global_variables_initializer())
    async_saver.save(path1)
    async_saver.save(path2, blocking=True)

    saver.save.assert_has_calls([mock.call(path1), mock.call(path2)])
  finally:
    # Have to close the saver to avoid hanging threads that will prevent OSS
    # tests from finishing. Doing it in `finally` keeps a failing assertion
    # from turning into a hung test run.
    async_saver.close()
def test_action_stops_on_last(self):
  """The time step returned for the final action comes from the episode end.

  This uses the same fixture, repeat count and action sequence as
  `test_action_repeated`. The inner env therefore sees the identical call
  sequence: `[2]` once, then `[3]` three times. The returned time step
  carries the accumulated reward (7) and the last observation (`[3]`).
  """
  mock_env = self._get_mock_env_episode()
  env = wrappers.ActionRepeat(mock_env, 3)
  env.reset()

  env.step([2])
  time_step = env.step([3])
  # Previously `[mock.call([2])] * 3 + [mock.call([3])]`, which contradicts
  # the call sequence asserted for the same interaction in
  # `test_action_repeated`.
  mock_env.step.assert_has_calls([mock.call([2])] + [mock.call([3])] * 3)

  self.assertEqual(7, time_step.reward)
  self.assertEqual([3], time_step.observation)
def test_task_handling(self, mock_publish):
  """Exec and cancel tasks queued before and after startup are all handled.

  Trainer and Transform are passed to the fake scheduler as `block_nodes`
  (presumably so they are still active when their cancel tasks arrive).
  Evaluator is never cancelled. Every exec task must be scheduled and
  published, and only Trainer and Transform are cancelled.
  """
  collector = _Collector()
  scheduler_factory = functools.partial(
      _FakeTaskScheduler,
      block_nodes={'Trainer', 'Transform'},
      collector=collector)
  # Register a fake task scheduler.
  ts.TaskSchedulerRegistry.register(self._type_url, scheduler_factory)

  def make_exec_task(node_id):
    return _test_exec_node_task(
        node_id, 'test-pipeline', pipeline=self._pipeline)

  task_queue = tq.TaskQueue()

  # Tasks queued before the task manager starts.
  trainer_exec_task = make_exec_task('Trainer')
  task_queue.enqueue(trainer_exec_task)
  task_queue.enqueue(_test_cancel_node_task('Trainer', 'test-pipeline'))

  with self._task_manager(task_queue) as task_manager:
    # Tasks queued while the task manager is running.
    transform_exec_task = make_exec_task('Transform')
    task_queue.enqueue(transform_exec_task)
    evaluator_exec_task = make_exec_task('Evaluator')
    task_queue.enqueue(evaluator_exec_task)
    task_queue.enqueue(_test_cancel_node_task('Transform', 'test-pipeline'))

  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  all_exec_tasks = [trainer_exec_task, transform_exec_task, evaluator_exec_task]
  self.assertCountEqual(all_exec_tasks, collector.scheduled_tasks)
  self.assertCountEqual([trainer_exec_task, transform_exec_task],
                        collector.cancelled_tasks)
  mock_publish.assert_has_calls(
      [mock.call(mlmd_handle=mock.ANY, task=exec_task, result=mock.ANY)
       for exec_task in all_exec_tasks],
      any_order=True)
def test_queue_multiplexing(self, mock_publish):
  """Tasks for several nodes sharing one queue are all dispatched.

  Exec and cancel tasks for Trainer, Transform and Evaluator are enqueued
  before and during the task manager's lifetime. Every exec task must be
  scheduled and published, and only Trainer and Transform are cancelled.
  """
  # Build a pipeline IR whose deployment config has an executor spec for
  # every node exercised below.
  executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
      class_path='trainer.TrainerExecutor')
  deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
  for node_id in ('Trainer', 'Transform', 'Evaluator'):
    deployment_config.executor_specs[node_id].Pack(executor_spec)
  pipeline = pipeline_pb2.Pipeline()
  pipeline.deployment_config.Pack(deployment_config)

  collector = _Collector()

  # All specs are packed from the same message, so one registration keyed on
  # the Trainer spec's type URL is used for every node.
  ts.TaskSchedulerRegistry.register(
      deployment_config.executor_specs['Trainer'].type_url,
      functools.partial(
          _FakeTaskScheduler,
          block_nodes={'Trainer', 'Transform'},
          collector=collector))

  task_queue = tq.TaskQueue()

  # Tasks queued before the task manager starts.
  trainer_exec_task = _test_exec_task('Trainer', 'test-pipeline')
  task_queue.enqueue(trainer_exec_task)
  task_queue.enqueue(_test_cancel_task('Trainer', 'test-pipeline'))

  with tm.TaskManager(
      mock.Mock(),
      pipeline,
      task_queue,
      max_active_task_schedulers=1000,
      max_dequeue_wait_secs=0.1,
      process_all_queued_tasks_before_exit=True):
    # Tasks queued while the task manager is running.
    transform_exec_task = _test_exec_task('Transform', 'test-pipeline')
    task_queue.enqueue(transform_exec_task)
    evaluator_exec_task = _test_exec_task('Evaluator', 'test-pipeline')
    task_queue.enqueue(evaluator_exec_task)
    task_queue.enqueue(_test_cancel_task('Transform', 'test-pipeline'))

  # Ensure that all exec and cancellation tasks were processed correctly.
  exec_tasks = [trainer_exec_task, transform_exec_task, evaluator_exec_task]
  self.assertCountEqual(exec_tasks, collector.scheduled_tasks)
  self.assertCountEqual([trainer_exec_task, transform_exec_task],
                        collector.cancelled_tasks)
  mock_publish.assert_has_calls(
      [mock.call(mock.ANY, pipeline, exec_task, mock.ANY)
       for exec_task in exec_tasks],
      any_order=True)
def testBlockingCheckpointSave(self):
  """A blocking checkpoint save returns only after queued saves have run.

  After `save_checkpoint(path2, blocking=True)` returns, the wrapped saver
  must already have received both checkpoint saves, in order.
  """
  saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
  async_saver = async_policy_saver.AsyncPolicySaver(saver)
  try:
    path1 = os.path.join(self.get_temp_dir(), 'save_model')
    path2 = os.path.join(self.get_temp_dir(), 'save_model2')

    self.evaluate(tf.compat.v1.global_variables_initializer())
    async_saver.save_checkpoint(path1)
    async_saver.save_checkpoint(path2, blocking=True)

    saver.save_checkpoint.assert_has_calls(
        [mock.call(path1), mock.call(path2)])
  finally:
    # Have to close the saver to avoid hanging threads that will prevent OSS
    # tests from finishing. Closing in `finally` ensures this also happens
    # when an assertion above fails.
    async_saver.close()
def test_after_train_step_fn_with_fresh_data_only(self, create_strategy_fn):
  """Staleness summaries for observations from the current policy version.

  The train step is 225 and every observation was generated at train step
  200, so the largest train-step delta in the batch is 225 - 200 = 25. With
  100 train steps per policy update, both steps presumably fall in the same
  policy version (225 // 100 == 200 // 100). No policy-update delta and no
  stale observations are therefore expected.

  Args:
    create_strategy_fn: Zero-argument factory for the distribution strategy
      under which the after-train function runs.
  """
  strategy = create_strategy_fn()
  with strategy.scope():
    # Prepare the test context.
    train_step = train_utils.create_train_step()
    train_step.assign(225)
    train_steps_per_policy_update = 100

    # Create the after train function to test, and the test input.
    after_train_step_fn = (
        train_utils.create_staleness_metrics_after_train_step_fn(
            train_step,
            train_steps_per_policy_update=train_steps_per_policy_update))
    observation_train_steps = np.array([[200], [200], [200]], dtype=np.int64)

    # Define the expectations (expected scalar summary calls).
    expected_scalar_summary_calls = [
        # Observations are 225 - 200 = 25 train steps old. This is non-zero
        # even though the data is fresh with respect to the policy version.
        mock.call(
            name='staleness/max_train_step_delta_in_batch', data=25, step=225),
        mock.call(
            name='staleness/max_policy_update_delta_in_batch',
            data=0,
            step=225),
        # NOTE: 'obserations' (sic) must match the summary name emitted by
        # `train_utils`; do not correct the spelling here alone.
        mock.call(
            name='staleness/num_stale_obserations_in_batch', data=0, step=225)
    ]

    # Call the after train function and check the expectations.
    with mock.patch.object(
        tf.summary, 'scalar', autospec=True) as mock_scalar_summary:
      # Assumed: the observation train steps are stored in the `priority`
      # field of the Reverb sample info.
      sample_info = reverb.replay_sample.SampleInfo(
          key=None,
          probability=None,
          table_size=None,
          priority=observation_train_steps)
      strategy.run(after_train_step_fn, args=((None, sample_info), None))

      # Check if the expected calls happened on the scalar summary.
      mock_scalar_summary.assert_has_calls(
          expected_scalar_summary_calls, any_order=False)
def test_exceptions_are_surfaced(self, mock_publish):
  """An error raised while publishing one task is surfaced by the manager.

  Publishing the Transform result raises `ValueError('test error')`, while
  the Trainer result publishes normally. The task manager must still finish
  and report exactly that one error inside a `tm.TasksProcessingError`. Both
  tasks must still have been scheduled and passed to publish.
  """

  def _publish(**kwargs):
    task = kwargs['task']
    assert task_lib.is_exec_node_task(task)
    if task.node_uid.node_id == 'Transform':
      raise ValueError('test error')
    # Fall back to the mock's normal return value for other tasks.
    return mock.DEFAULT

  mock_publish.side_effect = _publish

  collector = _Collector()

  # Register a fake task scheduler that blocks no nodes.
  ts.TaskSchedulerRegistry.register(
      self._type_url,
      functools.partial(
          _FakeTaskScheduler,
          # An empty *set* (`{}` would be an empty dict), matching the type
          # the other `_FakeTaskScheduler` callers pass.
          block_nodes=set(),
          collector=collector))

  task_queue = tq.TaskQueue()

  with self._task_manager(task_queue) as task_manager:
    transform_task = _test_exec_node_task(
        'Transform', 'test-pipeline', pipeline=self._pipeline)
    trainer_task = _test_exec_node_task(
        'Trainer', 'test-pipeline', pipeline=self._pipeline)
    task_queue.enqueue(transform_task)
    task_queue.enqueue(trainer_task)

  self.assertTrue(task_manager.done())
  exception = task_manager.exception()
  self.assertIsNotNone(exception)
  self.assertIsInstance(exception, tm.TasksProcessingError)
  self.assertLen(exception.errors, 1)
  self.assertEqual('test error', str(exception.errors[0]))

  self.assertCountEqual([transform_task, trainer_task],
                        collector.scheduled_tasks)
  mock_publish.assert_has_calls([
      mock.call(mlmd_handle=mock.ANY, task=transform_task, result=mock.ANY),
      mock.call(mlmd_handle=mock.ANY, task=trainer_task, result=mock.ANY),
  ], any_order=True)
def testLifeLoss(self):
  """Losing a life mid-game: replay buffer vs. observer trajectories.

  After a regular step, the env returns a termination (lost life) followed by
  a transition. Per the expected calls, the replay buffer receives a LAST
  trajectory (discount 1.0) followed by a FIRST one. The observer only sees
  MID trajectories, and the resulting time step is MID.
  """
  self._setup_mocks()
  trainer = self.trainer
  with self.cached_session() as sess:
    trainer._initialize_graph(sess)
    time_step = ts_restart(0)

    # Run a regular step, then forget what it recorded.
    trainer._env.step.return_value = ts_transition(1)
    time_step = trainer._collect_step(
        time_step, self.metric_observers, train=True)
    trainer._replay_buffer.add_batch.reset_mock()
    self.observer.reset_mock()

    # Lose a life, but not the end of a game.
    trainer._env.step.side_effect = [ts_termination(2), ts_transition(3)]
    time_step = trainer._collect_step(
        time_step, self.metric_observers, train=True)
    self.assertTrue(time_step.is_mid())

    self.assertEqual(
        [mock.call(trajectory_last(1, discount=1.0)),
         mock.call(trajectory_first(2))],
        trainer._replay_buffer.add_batch.call_args_list)
    self.assertEqual(
        [mock.call(trajectory_mid(1)), mock.call(trajectory_mid(2))],
        self.observer.call_args_list)
def test_schedule_successful_task_should_send_pubsub_message(self):
  """Each scheduling round publishes `max_parallel_tasks` job messages.

  `start()` publishes one round of messages. Running the job scheduler after
  the first task finishes publishes a second round, for a cumulative total of
  twice `max_parallel_tasks`.
  """
  job = self._define_job_with_two_dependent_tasks()
  scheduler = job.make_scheduler()
  job.start()

  payload = json.dumps({'id': job.id}).encode('utf-8')
  expected_call = mock.call(self.topic_path, data=payload)

  def check_published(num_messages):
    self.mock_pubsub.publish.assert_has_calls([expected_call] * num_messages)
    self.assertEqual(self.mock_pubsub.publish.call_count, num_messages)

  # start() sends #{max_parallel_tasks} pubsub messages.
  check_published(self.max_parallel_tasks)

  self._call_job_scheduler(job, scheduler)

  # When task1 finishes, another #{max_parallel_tasks} messages are sent to
  # trigger subsequent tasks.
  check_published(self.max_parallel_tasks * 2)
def test_upload_directory(self, mock_walk, mock_upload_file):
  """Every file yielded by the mocked walk is uploaded, in walk order.

  Each upload targets a blob path made by replacing the source directory
  prefix of the file's path with `destination_blob_path`.
  """
  destination_blob_path = 'dir1'
  file_structure = [['/tmp/dir1', ['dir2', 'dir3'], ['file1', 'file2']],
                    ['/tmp/dir1/dir2', [], ['file3']],
                    ['/tmp/dir1/dir3', ['dir4'], ['file4']],
                    ['/tmp/dir1/dir3/dir4', [], []]]
  mock_walk.return_value = file_structure

  source_paths = [
      root + '/' + name
      for root, _, names in file_structure
      for name in names
  ]
  expected_calls = [
      mock.call(
          self.cloud_storage_obj, path, self.mock_bucket,
          path.replace(self.source_directory_path, destination_blob_path))
      for path in source_paths
  ]

  self.cloud_storage_obj.upload_directory(self.source_directory_path,
                                          self.mock_bucket,
                                          destination_blob_path)
  mock_upload_file.assert_has_calls(expected_calls)
def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
                                  mock_async_task_gen, mock_sync_task_gen):
  """Orchestrating a stop-initiated pipeline winds down its node work.

  Flow exercised:
    1. Add four nodes (ExampleGen, Transform, Trainer, Evaluator), start the
       pipeline and initiate a stop with status CANCELLED.
    2. First `orchestrate` call, with the Transform exec task in flight:
       - neither regular task generator mock is called;
       - services are stopped for ExampleGen (the pure service node);
       - the queue then yields a CancelNodeTask for Transform, followed by
         the cancelled Trainer exec task returned by
         `mock_gen_task_from_active`;
       - the pipeline execution is still active.
    3. Second `orchestrate` call:
       - the pipeline execution ends CANCELED;
       - services are stopped for ExampleGen and Transform.

  Args:
    pipeline: Pipeline IR to populate. Presumably supplied by a
      parameterization decorator not visible here.
    mock_gen_task_from_active: Mock whose `side_effect` provides the task (or
      None) for each call; presumably patches task generation from an active
      execution.
    mock_async_task_gen: Async task generator mock; asserted not called.
    mock_sync_task_gen: Sync task generator mock; asserted not called.
  """
  with self._mlmd_connection as m:
    # Node order matters: later assertions index `pipeline.nodes`
    # positionally.
    pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

    mock_service_job_manager = mock.create_autospec(
        service_jobs.ServiceJobManager, instance=True)
    # ExampleGen is reported as a pure service node and Transform as a mixed
    # service node. The first predicate argument is ignored.
    mock_service_job_manager.is_pure_service_node.side_effect = (
        lambda _, node_id: node_id == 'ExampleGen')
    mock_service_job_manager.is_mixed_service_node.side_effect = (
        lambda _, node_id: node_id == 'Transform')

    pipeline_ops.initiate_pipeline_start(m, pipeline)
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      pipeline_state.initiate_stop(
          status_lib.Status(code=status_lib.Code.CANCELLED))
      pipeline_execution = pipeline_state.execution

    task_queue = tq.TaskQueue()

    # For the stop-initiated pipeline, "Transform" execution task is in queue,
    # "Trainer" has an active execution in MLMD but no task in queue,
    # "Evaluator" has no active execution.
    task_queue.enqueue(
        test_utils.create_exec_node_task(
            task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                node_id='Transform')))
    transform_task = task_queue.dequeue()  # simulates task being processed
    # Only the first call yields a task (a cancelled Trainer exec task). The
    # trailing `None`s serve any later calls.
    mock_gen_task_from_active.side_effect = [
        test_utils.create_exec_node_task(
            node_uid=task_lib.NodeUid(
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                node_id='Trainer'),
            is_cancelled=True), None, None, None, None
    ]
    pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

    # There are no active pipelines so these shouldn't be called.
    mock_async_task_gen.assert_not_called()
    mock_sync_task_gen.assert_not_called()

    # stop_node_services should be called for ExampleGen which is a pure
    # service node.
    mock_service_job_manager.stop_node_services.assert_called_once_with(
        mock.ANY, 'ExampleGen')
    # Clear recorded calls so the final `stop_node_services` check only sees
    # the second `orchestrate` call. By default `reset_mock()` keeps the
    # `side_effect`s configured above.
    mock_service_job_manager.reset_mock()

    task_queue.task_done(transform_task)  # Pop out transform task.

    # CancelNodeTask for the "Transform" ExecNodeTask should be next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_cancel_node_task(task))
    self.assertEqual('Transform', task.node_uid.node_id)

    # ExecNodeTask (with is_cancelled=True) for "Trainer" is next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual('Trainer', task.node_uid.node_id)
    self.assertTrue(task.is_cancelled)

    self.assertTrue(task_queue.is_empty())

    # Exactly two calls were made during the first `orchestrate`: for
    # `pipeline.nodes[2]` and `pipeline.nodes[3]`, both with
    # `is_cancelled=True`.
    mock_gen_task_from_active.assert_has_calls([
        mock.call(
            m,
            pipeline_state.pipeline,
            pipeline.nodes[2].pipeline_node,
            mock.ANY,
            is_cancelled=True),
        mock.call(
            m,
            pipeline_state.pipeline,
            pipeline.nodes[3].pipeline_node,
            mock.ANY,
            is_cancelled=True)
    ])
    self.assertEqual(2, mock_gen_task_from_active.call_count)

    # Pipeline execution should continue to be active since active node
    # executions were found in the last call to `orchestrate`.
    [execution] = m.store.get_executions_by_id([pipeline_execution.id])
    self.assertTrue(execution_lib.is_execution_active(execution))

    # Call `orchestrate` again; this time there are no more active node
    # executions so the pipeline should be marked as cancelled.
    pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
    self.assertTrue(task_queue.is_empty())
    [execution] = m.store.get_executions_by_id([pipeline_execution.id])
    self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                     execution.last_known_state)

    # stop_node_services should be called on both ExampleGen and Transform
    # which are service nodes.
    mock_service_job_manager.stop_node_services.assert_has_calls(
        [mock.call(mock.ANY, 'ExampleGen'),
         mock.call(mock.ANY, 'Transform')],
        any_order=True)
def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                        mock_async_task_gen,
                                        mock_sync_task_gen):
  """Generating tasks for a stop-initiated pipeline winds down its work.

  Flow exercised:
    1. Start `pipeline1` (Transform, Trainer, Evaluator) and initiate a stop.
    2. First `generate_tasks` call, with the Transform exec task in flight:
       - neither regular task generator mock is called;
       - the queue then yields a CancelNodeTask for Transform, followed by
         the Trainer exec task returned by `mock_gen_task_from_active`;
       - the pipeline execution is still active.
    3. Second `generate_tasks` call: the pipeline execution ends CANCELED.

  Args:
    mock_gen_task_from_active: Mock whose `side_effect` provides the task (or
      None) for each call; presumably patches task generation from an active
      execution.
    mock_async_task_gen: Async task generator mock; asserted not called.
    mock_sync_task_gen: Sync task generator mock; asserted not called.
  """
  with self._mlmd_connection as m:
    # Node order matters: later assertions index `pipeline1.nodes`
    # positionally.
    pipeline1 = _test_pipeline('pipeline1')
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline1.nodes.add().pipeline_node.node_info.id = 'Evaluator'
    pipeline_ops.initiate_pipeline_start(m, pipeline1)
    pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
        m, task_lib.PipelineUid.from_pipeline(pipeline1))

    task_queue = tq.TaskQueue()

    # For the stop-initiated pipeline, "Transform" execution task is in queue,
    # "Trainer" has an active execution in MLMD but no task in queue,
    # "Evaluator" has no active execution.
    task_queue.enqueue(
        test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
            pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline1',
                                              pipeline_run_id=None),
            node_id='Transform')))
    transform_task = task_queue.dequeue()  # simulates task being processed
    # Only the first call yields a task (a cancelled Trainer exec task). The
    # trailing `None`s serve any later calls.
    mock_gen_task_from_active.side_effect = [
        test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
            pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline1',
                                              pipeline_run_id=None),
            node_id='Trainer'),
                                         is_cancelled=True), None, None, None,
        None
    ]
    pipeline_ops.generate_tasks(m, task_queue)

    # There are no active pipelines so these shouldn't be called.
    mock_async_task_gen.assert_not_called()
    mock_sync_task_gen.assert_not_called()

    # Simulate finishing the "Transform" ExecNodeTask.
    task_queue.task_done(transform_task)

    # CancelNodeTask for the "Transform" ExecNodeTask should be next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_cancel_node_task(task))
    self.assertEqual('Transform', task.node_uid.node_id)

    # ExecNodeTask for "Trainer" is next.
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual('Trainer', task.node_uid.node_id)

    self.assertTrue(task_queue.is_empty())

    # Exactly two calls were made during the first `generate_tasks`: for
    # `pipeline1.nodes[1]` and `pipeline1.nodes[2]`, both with
    # `is_cancelled=True`.
    mock_gen_task_from_active.assert_has_calls([
        mock.call(m,
                  pipeline1,
                  pipeline1.nodes[1].pipeline_node,
                  mock.ANY,
                  is_cancelled=True),
        mock.call(m,
                  pipeline1,
                  pipeline1.nodes[2].pipeline_node,
                  mock.ANY,
                  is_cancelled=True)
    ])
    self.assertEqual(2, mock_gen_task_from_active.call_count)

    # Pipeline execution should continue to be active since active node
    # executions were found in the last call to `generate_tasks`.
    [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
    self.assertTrue(execution_lib.is_execution_active(execution))

    # Call `generate_tasks` again; this time there are no more active node
    # executions so the pipeline should be marked as cancelled.
    pipeline_ops.generate_tasks(m, task_queue)
    self.assertTrue(task_queue.is_empty())
    [execution] = m.store.get_executions_by_id([pipeline1_execution.id])
    self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                     execution.last_known_state)
def test_task_generation(self, use_task_queue):
  """Tests async pipeline task generation.

  Args:
    use_task_queue: If task queue is enabled, new tasks are only generated if a
      task with the same task_id does not already exist in the queue.
      `use_task_queue=False` is useful to test the case of task generation
      when task queue is empty (e.g. due to orchestrator restart).
  """

  def assert_update_task(task, expected_state):
    """Asserts `task` is an UpdateNodeStateTask carrying `expected_state`."""
    self.assertTrue(task_lib.is_update_node_state_task(task))
    self.assertEqual(expected_state, task.state)

  # Simulate that ExampleGen has already completed successfully.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)

  # Generate once.
  [update_example_gen_task, update_transform_task,
   exec_transform_task] = self._generate_and_test(
       use_task_queue,
       num_initial_executions=1,
       num_tasks_generated=3,
       num_new_executions=1,
       num_active_executions=1,
       expected_exec_nodes=[self._transform])
  assert_update_task(update_example_gen_task, pstate.NodeState.RUNNING)
  assert_update_task(update_transform_task, pstate.NodeState.RUNNING)
  self.assertTrue(task_lib.is_exec_node_task(exec_transform_task))

  self._mock_service_job_manager.ensure_node_services.assert_has_calls([
      mock.call(mock.ANY, self._example_gen.node_info.id),
      mock.call(mock.ANY, self._transform.node_info.id)
  ])

  # No new effects if generate called again.
  tasks = self._generate_and_test(
      use_task_queue,
      num_initial_executions=2,
      num_tasks_generated=1 if use_task_queue else 3,
      num_new_executions=0,
      num_active_executions=1,
      expected_exec_nodes=[] if use_task_queue else [self._transform])
  if not use_task_queue:
    # Without a queue the exec task is regenerated; track the new instance.
    exec_transform_task = tasks[2]

  # Mark transform execution complete.
  self._finish_node_execution(use_task_queue, exec_transform_task)

  # Trainer execution task should be generated next.
  [update_example_gen_task, update_transform_task, update_trainer_task,
   exec_trainer_task] = self._generate_and_test(
       use_task_queue,
       num_initial_executions=2,
       num_tasks_generated=4,
       num_new_executions=1,
       num_active_executions=1,
       expected_exec_nodes=[self._trainer])
  assert_update_task(update_example_gen_task, pstate.NodeState.RUNNING)
  assert_update_task(update_transform_task, pstate.NodeState.STARTED)
  assert_update_task(update_trainer_task, pstate.NodeState.RUNNING)
  self.assertTrue(task_lib.is_exec_node_task(exec_trainer_task))

  # Mark the trainer execution complete.
  self._finish_node_execution(use_task_queue, exec_trainer_task)

  # Only UpdateNodeStateTask are generated as there are no new inputs.
  # Check the states on the freshly generated tasks (previously only the
  # stale `update_example_gen_task` was re-checked). They match the final
  # no-new-inputs round at the end of this test.
  [update_example_gen_task, update_transform_task,
   update_trainer_task] = self._generate_and_test(
       use_task_queue,
       num_initial_executions=3,
       num_tasks_generated=3,
       num_new_executions=0,
       num_active_executions=0)
  assert_update_task(update_example_gen_task, pstate.NodeState.RUNNING)
  assert_update_task(update_transform_task, pstate.NodeState.STARTED)
  assert_update_task(update_trainer_task, pstate.NodeState.STARTED)

  # Fake another ExampleGen run.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)

  # Both transform and trainer tasks should be generated as they both find
  # new inputs.
  [update_example_gen_task, update_transform_task, exec_transform_task,
   update_trainer_task, exec_trainer_task] = self._generate_and_test(
       use_task_queue,
       num_initial_executions=4,
       num_tasks_generated=5,
       num_new_executions=2,
       num_active_executions=2,
       expected_exec_nodes=[self._transform, self._trainer])
  assert_update_task(update_example_gen_task, pstate.NodeState.RUNNING)
  assert_update_task(update_transform_task, pstate.NodeState.RUNNING)
  self.assertTrue(task_lib.is_exec_node_task(exec_transform_task))
  assert_update_task(update_trainer_task, pstate.NodeState.RUNNING)
  self.assertTrue(task_lib.is_exec_node_task(exec_trainer_task))

  # Re-generation will produce the same tasks when task queue disabled.
  tasks = self._generate_and_test(
      use_task_queue,
      num_initial_executions=6,
      num_tasks_generated=1 if use_task_queue else 5,
      num_new_executions=0,
      num_active_executions=2,
      expected_exec_nodes=[]
      if use_task_queue else [self._transform, self._trainer])
  if not use_task_queue:
    # Same layout as the previous round: ExampleGen update, Transform
    # update + exec, Trainer update + exec. States are checked on these
    # tasks themselves, not on the earlier `update_example_gen_task`.
    assert_update_task(tasks[0], pstate.NodeState.RUNNING)
    assert_update_task(tasks[1], pstate.NodeState.RUNNING)
    self.assertTrue(task_lib.is_exec_node_task(tasks[2]))
    assert_update_task(tasks[3], pstate.NodeState.RUNNING)
    self.assertTrue(task_lib.is_exec_node_task(tasks[4]))
    exec_transform_task = tasks[2]
    exec_trainer_task = tasks[4]
  else:
    assert_update_task(tasks[0], pstate.NodeState.RUNNING)

  # Mark transform execution complete.
  self._finish_node_execution(use_task_queue, exec_transform_task)

  # Mark the trainer execution complete.
  self._finish_node_execution(use_task_queue, exec_trainer_task)

  # Trainer should be triggered again due to transform producing new output.
  [update_example_gen_task, update_transform_task, update_trainer_task,
   exec_trainer_task] = self._generate_and_test(
       use_task_queue,
       num_initial_executions=6,
       num_tasks_generated=4,
       num_new_executions=1,
       num_active_executions=1,
       expected_exec_nodes=[self._trainer])
  assert_update_task(update_example_gen_task, pstate.NodeState.RUNNING)
  assert_update_task(update_transform_task, pstate.NodeState.STARTED)
  assert_update_task(update_trainer_task, pstate.NodeState.RUNNING)
  self.assertTrue(task_lib.is_exec_node_task(exec_trainer_task))

  # Finally, no new tasks once trainer completes.
  self._finish_node_execution(use_task_queue, exec_trainer_task)
  [update_example_gen_task, update_transform_task,
   update_trainer_task] = self._generate_and_test(
       use_task_queue,
       num_initial_executions=7,
       num_tasks_generated=3,
       num_new_executions=0,
       num_active_executions=0)
  assert_update_task(update_example_gen_task, pstate.NodeState.RUNNING)
  assert_update_task(update_transform_task, pstate.NodeState.STARTED)
  assert_update_task(update_trainer_task, pstate.NodeState.STARTED)

  if use_task_queue:
    self.assertTrue(self._task_queue.is_empty())
def test_task_handling(self, mock_publish):
  """Exec and cancel tasks are processed; a paused cancel skips publishing.

  Trainer, Transform and Pusher are passed to the fake scheduler as
  `block_nodes` (presumably so they are still active when cancelled). Pusher
  is cancelled with `pause=True`, and per the final assertion its result is
  never published.
  """
  collector = _Collector()

  # Register a fake task scheduler.
  ts.TaskSchedulerRegistry.register(
      self._type_url,
      functools.partial(
          _FakeTaskScheduler,
          block_nodes={'Trainer', 'Transform', 'Pusher'},
          collector=collector))

  def exec_task_for(node_id):
    return _test_exec_node_task(
        node_id, 'test-pipeline', pipeline=self._pipeline)

  def cancel_task_for(node_id, **kwargs):
    return _test_cancel_node_task(node_id, 'test-pipeline', **kwargs)

  task_queue = tq.TaskQueue()

  # Tasks queued before the task manager starts.
  trainer_exec_task = exec_task_for('Trainer')
  task_queue.enqueue(trainer_exec_task)
  task_queue.enqueue(cancel_task_for('Trainer'))

  with self._task_manager(task_queue) as task_manager:
    # Tasks queued while the task manager is running.
    transform_exec_task = exec_task_for('Transform')
    task_queue.enqueue(transform_exec_task)
    evaluator_exec_task = exec_task_for('Evaluator')
    task_queue.enqueue(evaluator_exec_task)
    task_queue.enqueue(cancel_task_for('Transform'))
    pusher_exec_task = exec_task_for('Pusher')
    task_queue.enqueue(pusher_exec_task)
    task_queue.enqueue(cancel_task_for('Pusher', pause=True))

  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # Ensure that all exec and cancellation tasks were processed correctly.
  self.assertCountEqual(
      [trainer_exec_task, transform_exec_task, evaluator_exec_task,
       pusher_exec_task],
      collector.scheduled_tasks)
  self.assertCountEqual(
      [trainer_exec_task, transform_exec_task, pusher_exec_task],
      collector.cancelled_tasks)

  def fake_result(code):
    return ts.TaskSchedulerResult(
        status=status_lib.Status(
            code=code, message='_FakeTaskScheduler result'))

  result_ok = fake_result(status_lib.Code.OK)
  result_cancelled = fake_result(status_lib.Code.CANCELLED)
  mock_publish.assert_has_calls([
      mock.call(mlmd_handle=mock.ANY, task=trainer_exec_task,
                result=result_cancelled),
      mock.call(mlmd_handle=mock.ANY, task=transform_exec_task,
                result=result_cancelled),
      mock.call(mlmd_handle=mock.ANY, task=evaluator_exec_task,
                result=result_ok),
  ], any_order=True)
  # It is expected that publish is not called for Pusher because it was
  # cancelled with pause=True so there must be only 3 calls.
  self.assertLen(mock_publish.mock_calls, 3)