def schedule(self) -> task_scheduler.TaskSchedulerResult:
  """Generates the importer node's output artifacts and returns OK."""

  def _props_from(proto_map) -> Dict[str, types.Property]:
    # Convert an MLMD property proto map into plain Python values.
    return {key: data_types_utils.get_value(val)
            for key, val in proto_map.items()}

  pipeline_node = self.task.get_pipeline_node()
  artifact_spec = pipeline_node.outputs.outputs[
      importer.IMPORT_RESULT_KEY].artifact_spec
  output_artifacts = importer.generate_output_dict(
      metadata_handler=self.mlmd_handle,
      uri=str(self.task.exec_properties[importer.SOURCE_URI_KEY]),
      properties=_props_from(artifact_spec.additional_properties),
      custom_properties=_props_from(
          artifact_spec.additional_custom_properties),
      reimport=bool(self.task.exec_properties[importer.REIMPORT_OPTION_KEY]),
      output_artifact_class=types.Artifact(artifact_spec.type).type,
      mlmd_artifact_type=artifact_spec.type)
  return task_scheduler.TaskSchedulerResult(
      status=status_lib.Status(code=status_lib.Code.OK),
      output=task_scheduler.ImporterNodeOutput(
          output_artifacts=output_artifacts))
def _process_exec_node_task(self, scheduler: ts.TaskScheduler,
                            task: task_lib.ExecNodeTask) -> None:
  """Processes an `ExecNodeTask` using the given task scheduler."""
  # `schedule()` blocks and may run for a long time. Schedulers are expected
  # to trap their own errors and surface them via the result status; if one
  # leaks an exception anyway, the execution is treated as failed (ABORTED)
  # and MLMD is updated accordingly.
  try:
    result = scheduler.schedule()
  except Exception as e:  # pylint: disable=broad-except
    logging.exception(
        'Exception raised by task scheduler; node uid: %s', task.node_uid)
    result = ts.TaskSchedulerResult(
        status=status_lib.Status(
            code=status_lib.Code.ABORTED, message=str(e)))
  logging.info(
      'For ExecNodeTask id: %s, task-scheduler result status: %s',
      task.task_id, result.status)
  _publish_execution_results(
      mlmd_handle=self._mlmd_handle, task=task, result=result)
  # Record when MLMD was last touched; readers use this under the same lock.
  with self._publish_time_lock:
    self._last_mlmd_publish_time = time.time()
  # Deregister the scheduler and mark the task done under the manager lock.
  with self._tm_lock:
    del self._scheduler_by_node_uid[task.node_uid]
    self._task_queue.task_done(task)
def test_executor_failure(self):
  """Scheduler returns OK but the executor itself reports failure."""
  # Register a fake task scheduler that returns success but the executor
  # was cancelled.
  self._register_task_scheduler(
      ts.TaskSchedulerResult(
          status=status_lib.Status(code=status_lib.Code.OK),
          output=ts.ExecutorNodeOutput(
              executor_output=_make_executor_output(
                  self._task,
                  code=status_lib.Code.FAILED_PRECONDITION,
                  msg='foobar error'))))
  task_manager = self._run_task_manager()
  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # The task must be consumed and the MLMD execution marked FAILED with the
  # executor's error message recorded.
  self.assertTrue(self._task_queue.is_empty())
  execution = self._get_execution()
  self.assertEqual(metadata_store_pb2.Execution.FAILED,
                   execution.last_known_state)
  error_msg = data_types_utils.get_metadata_value(
      execution.custom_properties[constants.EXECUTION_ERROR_MSG_KEY])
  self.assertEqual('foobar error', error_msg)

  # Failure cleanup: stateful working dir, tmp_dir and the output artifact
  # URI must all be removed.
  self.assertFalse(os.path.exists(self._task.stateful_working_dir))
  self.assertFalse(os.path.exists(self._task.tmp_dir))
  self.assertFalse(os.path.exists(self._output_artifact_uri))
def schedule(self):
  """Records the task, optionally blocking until the stop event fires."""
  logging.info('_FakeTaskScheduler: scheduling task: %s', self.task)
  self._collector.add_scheduled_task(self.task)
  should_block = self.task.node_uid.node_id in self._block_nodes
  if should_block:
    self._stop_event.wait()
  return ts.TaskSchedulerResult(
      executor_output=execution_result_pb2.ExecutorOutput())
def schedule(self) -> task_scheduler.TaskSchedulerResult:
  """Polls MLMD until the manual node is completed or the task is cancelled."""
  while True:
    # `wait` returns True when cancellation was signalled within the
    # polling interval; otherwise we re-read the node state from MLMD.
    if self._cancel.wait(_POLLING_INTERVAL_SECS):
      return task_scheduler.TaskSchedulerResult(
          status=status_lib.Status(code=status_lib.Code.CANCELLED),
          output=task_scheduler.ExecutorNodeOutput())
    with mlmd_state.mlmd_execution_atomic_op(
        mlmd_handle=self.mlmd_handle,
        execution_id=self.task.execution_id) as execution:
      state_value = execution.custom_properties.get(NODE_STATE_PROPERTY_KEY)
      node_state = ManualNodeState.from_mlmd_value(state_value)
      if node_state.state == ManualNodeState.COMPLETED:
        return task_scheduler.TaskSchedulerResult(
            status=status_lib.Status(code=status_lib.Code.OK),
            output=task_scheduler.ExecutorNodeOutput())
def schedule(self):
  """Records the task; blocked nodes wait for cancel and report CANCELLED."""
  logging.info('_FakeTaskScheduler: scheduling task: %s', self.task)
  self._collector.add_scheduled_task(self.task)
  blocked = self.task.node_uid.node_id in self._block_nodes
  if blocked:
    self._cancel.wait()
  result_code = (
      status_lib.Code.CANCELLED if blocked else status_lib.Code.OK)
  return ts.TaskSchedulerResult(
      status=status_lib.Status(
          code=result_code, message='_FakeTaskScheduler result'))
def schedule(self) -> ts.TaskSchedulerResult:
  """Copies the task's output artifacts into an OK executor output."""
  logging.info('Processing ExecNodeTask: %s', self.task)
  executor_output = execution_result_pb2.ExecutorOutput()
  executor_output.execution_result.code = status_lib.Code.OK
  # Mirror every declared output artifact into the executor output proto.
  for output_key, artifact_list in self.task.output_artifacts.items():
    for art in artifact_list:
      executor_output.output_artifacts[output_key].artifacts.add().CopyFrom(
          art.mlmd_artifact)
  result = ts.TaskSchedulerResult(
      status=status_lib.Status(code=status_lib.Code.OK),
      output=ts.ExecutorNodeOutput(executor_output=executor_output))
  logging.info('Result: %s', result)
  return result
def test_exceptions_are_surfaced(self, mock_publish):
  """An exception from publishing one task surfaces via the task manager."""

  def _publish(**kwargs):
    task = kwargs['task']
    assert task_lib.is_exec_node_task(task)
    # Fail publishing only for the Transform node.
    if task.node_uid.node_id == 'Transform':
      raise ValueError('test error')
    return mock.DEFAULT

  mock_publish.side_effect = _publish

  collector = _Collector()
  # Register a fake task scheduler.
  ts.TaskSchedulerRegistry.register(
      self._type_url,
      functools.partial(
          _FakeTaskScheduler, block_nodes={}, collector=collector))

  task_queue = tq.TaskQueue()
  with self._task_manager(task_queue) as task_manager:
    transform_task = _test_exec_node_task(
        'Transform', 'test-pipeline', pipeline=self._pipeline)
    trainer_task = _test_exec_node_task(
        'Trainer', 'test-pipeline', pipeline=self._pipeline)
    task_queue.enqueue(transform_task)
    task_queue.enqueue(trainer_task)

  self.assertTrue(task_manager.done())
  exception = task_manager.exception()
  self.assertIsNotNone(exception)
  self.assertIsInstance(exception, tm.TasksProcessingError)
  # Only Transform's publish raised, so exactly one error is collected.
  self.assertLen(exception.errors, 1)
  self.assertEqual('test error', str(exception.errors[0]))

  # Both tasks were still scheduled despite the publish failure.
  self.assertCountEqual([transform_task, trainer_task],
                        collector.scheduled_tasks)
  result_ok = ts.TaskSchedulerResult(
      status=status_lib.Status(
          code=status_lib.Code.OK, message='_FakeTaskScheduler result'))
  mock_publish.assert_has_calls([
      mock.call(mlmd_handle=mock.ANY, task=transform_task, result=result_ok),
      mock.call(mlmd_handle=mock.ANY, task=trainer_task, result=result_ok),
  ], any_order=True)
def test_scheduler_failure(self):
  """A non-OK scheduler status marks the MLMD execution as FAILED."""
  # Register a fake task scheduler that returns a failure status.
  self._register_task_scheduler(
      ts.TaskSchedulerResult(
          status=status_lib.Status(code=status_lib.Code.ABORTED),
          executor_output=None))
  task_manager = self._run_task_manager()
  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # The task must be consumed and the execution marked FAILED.
  self.assertTrue(self._task_queue.is_empty())
  execution = self._get_execution()
  self.assertEqual(metadata_store_pb2.Execution.FAILED,
                   execution.last_known_state)
def test_successful_execution_resulting_in_output_artifacts(self):
  """An OK scheduler status with artifacts marks the execution COMPLETE."""
  # Register a fake task scheduler that returns a successful execution
  # result and `OK` task scheduler status.
  self._register_task_scheduler(
      ts.TaskSchedulerResult(
          status=status_lib.Status(code=status_lib.Code.OK),
          output_artifacts=self._task.output_artifacts))
  task_manager = self._run_task_manager()
  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # The task must be consumed and the execution marked COMPLETE.
  self.assertTrue(self._task_queue.is_empty())
  execution = self._get_execution()
  self.assertEqual(metadata_store_pb2.Execution.COMPLETE,
                   execution.last_known_state)
def test_scheduler_failure(self):
  """A non-OK scheduler status marks the execution FAILED with its message."""
  # Register a fake task scheduler that returns a failure status.
  self._register_task_scheduler(
      ts.TaskSchedulerResult(
          status=status_lib.Status(
              code=status_lib.Code.ABORTED, message='foobar error'),
          executor_output=None))
  task_manager = self._run_task_manager()
  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # The task must be consumed, the execution marked FAILED, and the
  # scheduler's error message recorded on the execution.
  self.assertTrue(self._task_queue.is_empty())
  execution = self._get_execution()
  self.assertEqual(metadata_store_pb2.Execution.FAILED,
                   execution.last_known_state)
  error_msg = data_types_utils.get_metadata_value(
      execution.custom_properties[constants.EXECUTION_ERROR_MSG_KEY])
  self.assertEqual('foobar error', error_msg)
def test_successful_execution_resulting_in_output_artifacts(self):
  """An OK importer result completes the execution and keeps its outputs."""
  # Register a fake task scheduler that returns a successful execution
  # result and `OK` task scheduler status.
  self._register_task_scheduler(
      ts.TaskSchedulerResult(
          status=status_lib.Status(code=status_lib.Code.OK),
          output=ts.ImporterNodeOutput(
              output_artifacts=self._task.output_artifacts)))
  task_manager = self._run_task_manager()
  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # The task must be consumed and the execution marked COMPLETE.
  self.assertTrue(self._task_queue.is_empty())
  execution = self._get_execution()
  self.assertEqual(metadata_store_pb2.Execution.COMPLETE,
                   execution.last_known_state)

  # Scratch dirs are always cleaned up...
  self.assertFalse(os.path.exists(self._task.stateful_working_dir))
  self.assertFalse(os.path.exists(self._task.tmp_dir))
  # ...but the output artifact URI survives a successful execution.
  self.assertTrue(os.path.exists(self._output_artifact_uri))
def schedule(self) -> task_scheduler.TaskSchedulerResult:
  """Passes the task's input artifacts through as the resolver's output."""
  resolved = self.task.input_artifacts
  return task_scheduler.TaskSchedulerResult(
      status=status_lib.Status(code=status_lib.Code.OK),
      output=task_scheduler.ResolverNodeOutput(
          resolved_input_artifacts=resolved))
def schedule(self):
  """Returns a scheduler result carrying an empty executor output."""
  empty_output = execution_result_pb2.ExecutorOutput()
  return ts.TaskSchedulerResult(executor_output=empty_output)
def test_task_handling(self, mock_publish):
  """Exec and cancel tasks are routed to schedulers and published correctly."""
  collector = _Collector()

  # Register a fake task scheduler that blocks on the listed nodes until
  # cancelled.
  ts.TaskSchedulerRegistry.register(
      self._type_url,
      functools.partial(
          _FakeTaskScheduler,
          block_nodes={'Trainer', 'Transform', 'Pusher'},
          collector=collector))

  task_queue = tq.TaskQueue()

  def _enqueue_exec_task(node_id):
    # Build and enqueue an ExecNodeTask for `node_id`, returning it for
    # later assertions.
    exec_task = _test_exec_node_task(
        node_id, 'test-pipeline', pipeline=self._pipeline)
    task_queue.enqueue(exec_task)
    return exec_task

  # Enqueue some tasks before the task manager starts.
  trainer_exec_task = _enqueue_exec_task('Trainer')
  task_queue.enqueue(_test_cancel_node_task('Trainer', 'test-pipeline'))

  with self._task_manager(task_queue) as task_manager:
    # Enqueue more tasks after task manager starts.
    transform_exec_task = _enqueue_exec_task('Transform')
    evaluator_exec_task = _enqueue_exec_task('Evaluator')
    task_queue.enqueue(_test_cancel_node_task('Transform', 'test-pipeline'))
    pusher_exec_task = _enqueue_exec_task('Pusher')
    task_queue.enqueue(
        _test_cancel_node_task('Pusher', 'test-pipeline', pause=True))

  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # Ensure that all exec and cancellation tasks were processed correctly.
  self.assertCountEqual([
      trainer_exec_task,
      transform_exec_task,
      evaluator_exec_task,
      pusher_exec_task,
  ], collector.scheduled_tasks)
  self.assertCountEqual([
      trainer_exec_task,
      transform_exec_task,
      pusher_exec_task,
  ], collector.cancelled_tasks)

  result_ok = ts.TaskSchedulerResult(
      status=status_lib.Status(
          code=status_lib.Code.OK, message='_FakeTaskScheduler result'))
  result_cancelled = ts.TaskSchedulerResult(
      status=status_lib.Status(
          code=status_lib.Code.CANCELLED,
          message='_FakeTaskScheduler result'))
  mock_publish.assert_has_calls([
      mock.call(
          mlmd_handle=mock.ANY,
          task=trainer_exec_task,
          result=result_cancelled),
      mock.call(
          mlmd_handle=mock.ANY,
          task=transform_exec_task,
          result=result_cancelled),
      mock.call(
          mlmd_handle=mock.ANY,
          task=evaluator_exec_task,
          result=result_ok),
  ], any_order=True)
  # It is expected that publish is not called for Pusher because it was
  # cancelled with pause=True so there must be only 3 calls.
  self.assertLen(mock_publish.mock_calls, 3)