def stop_pipeline(mlmd_handle: metadata.Metadata,
                  pipeline_uid: task_lib.PipelineUid,
                  timeout_secs: Optional[float] = None) -> None:
  """Stops a pipeline.

  Initiates a pipeline stop operation and waits for the pipeline execution to
  be gracefully stopped in the orchestration loop.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline_uid: Uid of the pipeline to be stopped.
    timeout_secs: Amount of time in seconds to wait for pipeline to stop. If
      `None`, waits indefinitely.

  Raises:
    status_lib.StatusNotOkError: Failure to initiate pipeline stop.
  """
  logging.info('Received request to stop pipeline; pipeline uid: %s',
               pipeline_uid)
  with _PIPELINE_OPS_LOCK:
    with pstate.PipelineState.load(mlmd_handle, pipeline_uid) as pipeline_state:
      pipeline_state.initiate_stop(
          status_lib.Status(
              code=status_lib.Code.CANCELLED,
              message='Cancellation requested by client.'))
  logging.info('Waiting for pipeline to be stopped; pipeline uid: %s',
               pipeline_uid)
  _wait_for_inactivation(
      mlmd_handle, pipeline_state.execution_id, timeout_secs=timeout_secs)
  logging.info('Done waiting for pipeline to be stopped; pipeline uid: %s',
               pipeline_uid)
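# A minimal usage sketch of `stop_pipeline` (not part of the library code).
# The connection config, pipeline IR argument, and the 300-second timeout are
# illustrative assumptions; only `metadata.Metadata`, `task_lib.PipelineUid`
# and `stop_pipeline` itself come from the snippets above.
def _example_stop_pipeline_call(connection_config, pipeline) -> None:
  """Hypothetical caller that requests a graceful pipeline stop."""
  with metadata.Metadata(connection_config=connection_config) as mlmd_handle:
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    # Raises status_lib.StatusNotOkError (e.g. NOT_FOUND) if the pipeline is
    # not active; otherwise waits up to `timeout_secs` for the orchestration
    # loop to finish cancelling it.
    stop_pipeline(mlmd_handle, pipeline_uid, timeout_secs=300.0)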
def test_stop_pipeline_non_existent_or_inactive(self, pipeline):
  with self._mlmd_connection as m:
    # Stop pipeline without creating one.
    with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
      pipeline_ops.stop_pipeline(m,
                                 task_lib.PipelineUid.from_pipeline(pipeline))
    self.assertEqual(status_lib.Code.NOT_FOUND,
                     exception_context.exception.code)

    # Initiate pipeline start and mark it completed.
    pipeline_ops.initiate_pipeline_start(m, pipeline)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      pipeline_state.initiate_stop(status_lib.Status(code=status_lib.Code.OK))
      pipeline_state.set_pipeline_execution_state(
          metadata_store_pb2.Execution.COMPLETE)

    # Try to initiate stop again.
    with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
      pipeline_ops.stop_pipeline(m, pipeline_uid)
    self.assertEqual(status_lib.Code.NOT_FOUND,
                     exception_context.exception.code)
def stop_pipeline(
    mlmd_handle: metadata.Metadata,
    pipeline_uid: task_lib.PipelineUid,
    timeout_secs: float = DEFAULT_WAIT_FOR_INACTIVATION_TIMEOUT_SECS
) -> None:
  """Stops a pipeline.

  Initiates a pipeline stop operation and waits for the pipeline execution to
  be gracefully stopped in the orchestration loop.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline_uid: Uid of the pipeline to be stopped.
    timeout_secs: Amount of time in seconds to wait for pipeline to stop.

  Raises:
    status_lib.StatusNotOkError: Failure to initiate pipeline stop.
  """
  with _PIPELINE_OPS_LOCK:
    with pstate.PipelineState.load(mlmd_handle, pipeline_uid) as pipeline_state:
      pipeline_state.initiate_stop(
          status_lib.Status(
              code=status_lib.Code.CANCELLED,
              message='Cancellation requested by client.'))
  _wait_for_inactivation(
      mlmd_handle, pipeline_state.execution_id, timeout_secs=timeout_secs)
def schedule(self) -> task_scheduler.TaskSchedulerResult:

  def _as_dict(proto_map) -> Dict[str, types.Property]:
    return {k: data_types_utils.get_value(v) for k, v in proto_map.items()}

  task = typing.cast(task_lib.ExecNodeTask, self.task)
  pipeline_node = task.get_pipeline_node()
  output_spec = pipeline_node.outputs.outputs[importer.IMPORT_RESULT_KEY]
  properties = _as_dict(output_spec.artifact_spec.additional_properties)
  custom_properties = _as_dict(
      output_spec.artifact_spec.additional_custom_properties)

  output_artifacts = importer.generate_output_dict(
      metadata_handler=self.mlmd_handle,
      uri=str(task.exec_properties[importer.SOURCE_URI_KEY]),
      properties=properties,
      custom_properties=custom_properties,
      reimport=bool(task.exec_properties[importer.REIMPORT_OPTION_KEY]),
      output_artifact_class=types.Artifact(output_spec.artifact_spec.type).type,
      mlmd_artifact_type=output_spec.artifact_spec.type)

  return task_scheduler.TaskSchedulerResult(
      status=status_lib.Status(code=status_lib.Code.OK),
      output_artifacts=output_artifacts)
def schedule(self) -> task_scheduler.TaskSchedulerResult:
  while not self._cancel.wait(_POLLING_INTERVAL_SECS):
    with mlmd_state.mlmd_execution_atomic_op(
        mlmd_handle=self.mlmd_handle,
        execution_id=self.task.execution_id) as execution:
      node_state_mlmd_value = execution.custom_properties.get(
          NODE_STATE_PROPERTY_KEY)
      node_state = ManualNodeState.from_mlmd_value(node_state_mlmd_value)
      if node_state.state == ManualNodeState.COMPLETED:
        return task_scheduler.TaskSchedulerResult(
            status=status_lib.Status(code=status_lib.Code.OK),
            output=task_scheduler.ExecutorNodeOutput())
  return task_scheduler.TaskSchedulerResult(
      status=status_lib.Status(code=status_lib.Code.CANCELLED),
      output=task_scheduler.ExecutorNodeOutput())
def schedule(self):
  logging.info('_FakeTaskScheduler: scheduling task: %s', self.task)
  self._collector.add_scheduled_task(self.task)
  if self.task.node_uid.node_id in self._block_nodes:
    self._stop_event.wait()
  return ts.TaskSchedulerResult(
      status=status_lib.Status(code=status_lib.Code.OK),
      executor_output=execution_result_pb2.ExecutorOutput())
def _abort_task(self, error_msg: str) -> task_lib.FinalizePipelineTask:
  """Returns task to abort pipeline execution."""
  error_msg = (f'Aborting pipeline execution due to node execution failure; '
               f'error: {error_msg}')
  logging.error(error_msg)
  return task_lib.FinalizePipelineTask(
      pipeline_uid=self._pipeline_uid,
      status=status_lib.Status(
          code=status_lib.Code.ABORTED, message=error_msg))
def schedule(self):
  logging.info('_FakeTaskScheduler: scheduling task: %s', self.task)
  self._collector.add_scheduled_task(self.task)
  if self.task.node_uid.node_id in self._block_nodes:
    self._cancel.wait()
    code = status_lib.Code.CANCELLED
  else:
    code = status_lib.Code.OK
  return ts.TaskSchedulerResult(
      status=status_lib.Status(
          code=code, message='_FakeTaskScheduler result'))
def stop_node(
    mlmd_handle: metadata.Metadata,
    node_uid: task_lib.NodeUid,
    timeout_secs: float = DEFAULT_WAIT_FOR_INACTIVATION_TIMEOUT_SECS
) -> None:
  """Stops a node in an async pipeline.

  Initiates a node stop operation and waits for the node execution to become
  inactive.

  Args:
    mlmd_handle: A handle to the MLMD db.
    node_uid: Uid of the node to be stopped.
    timeout_secs: Amount of time in seconds to wait for node to stop.

  Raises:
    status_lib.StatusNotOkError: Failure to stop the node.
  """
  with _PIPELINE_OPS_LOCK:
    with pstate.PipelineState.load(mlmd_handle,
                                   node_uid.pipeline_uid) as pipeline_state:
      nodes = pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
      filtered_nodes = [n for n in nodes if n.node_info.id == node_uid.node_id]
      if len(filtered_nodes) != 1:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.INTERNAL,
            message=(
                f'`stop_node` operation failed, unable to find node to stop: '
                f'{node_uid}'))
      node = filtered_nodes[0]
      pipeline_state.initiate_node_stop(
          node_uid,
          status_lib.Status(
              code=status_lib.Code.CANCELLED,
              message='Cancellation requested by client.'))

    executions = task_gen_utils.get_executions(mlmd_handle, node)
    active_executions = [
        e for e in executions if execution_lib.is_execution_active(e)
    ]
    if not active_executions:
      # If there are no active executions, we're done.
      return
    if len(active_executions) > 1:
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.INTERNAL,
          message=(
              f'Unexpected multiple active executions for node: {node_uid}'))
    _wait_for_inactivation(
        mlmd_handle, active_executions[0], timeout_secs=timeout_secs)
def _abort_node_task(self,
                     node_uid: task_lib.NodeUid) -> task_lib.FinalizeNodeTask:
  """Returns task to abort the node execution."""
  logging.error('Required service node not running or healthy, node uid: %s',
                node_uid)
  return task_lib.FinalizeNodeTask(
      node_uid=node_uid,
      status=status_lib.Status(
          code=status_lib.Code.ABORTED,
          message=(f'Aborting node execution as the associated service '
                   f'job is not running or healthy; problematic node '
                   f'uid: {node_uid}')))
def stop_initiated_reason(self) -> Optional[status_lib.Status]:
  """Returns status object if stop initiated, `None` otherwise."""
  custom_properties = self.execution.custom_properties
  if _get_metadata_value(custom_properties.get(_STOP_INITIATED)) == 1:
    code = _get_metadata_value(custom_properties.get(_PIPELINE_STATUS_CODE))
    if code is None:
      code = status_lib.Code.UNKNOWN
    message = _get_metadata_value(custom_properties.get(_PIPELINE_STATUS_MSG))
    return status_lib.Status(code=code, message=message)
  else:
    return None
def schedule(self) -> ts.TaskSchedulerResult:
  logging.info('Processing ExecNodeTask: %s', self.task)
  executor_output = execution_result_pb2.ExecutorOutput()
  executor_output.execution_result.code = status_lib.Code.OK
  for key, artifacts in self.task.output_artifacts.items():
    for artifact in artifacts:
      executor_output.output_artifacts[key].artifacts.add().CopyFrom(
          artifact.mlmd_artifact)
  result = ts.TaskSchedulerResult(
      status=status_lib.Status(code=status_lib.Code.OK),
      output=ts.ExecutorNodeOutput(executor_output=executor_output))
  logging.info('Result: %s', result)
  return result
def test_initiate_node_start_stop(self):
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1')
    node_uid = task_lib.NodeUid(
        node_id='Trainer',
        pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline))
    with pstate.PipelineState.new(m, pipeline) as pipeline_state:
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        node_state.update(pstate.NodeState.STARTING)
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STARTING, node_state.state)

    # Reload from MLMD and verify node is started.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STARTING, node_state.state)

      # Set node state to STOPPING.
      status = status_lib.Status(
          code=status_lib.Code.ABORTED, message='foo bar')
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        node_state.update(pstate.NodeState.STOPPING, status)
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STOPPING, node_state.state)
      self.assertEqual(status, node_state.status)

    # Reload from MLMD and verify node is stopped.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STOPPING, node_state.state)
      self.assertEqual(status, node_state.status)

      # Set node state to STARTED.
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        node_state.update(pstate.NodeState.STARTED)
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STARTED, node_state.state)

    # Reload from MLMD and verify node is started.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      node_state = pipeline_state.get_node_state(node_uid)
      self.assertEqual(pstate.NodeState.STARTED, node_state.state)
def test_pipeline_view_get_node_run_states(self):
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Pusher'
    eg_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
    transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform')
    trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')
    evaluator_node_uid = task_lib.NodeUid(pipeline_uid, 'Evaluator')
    with pstate.PipelineState.new(m, pipeline) as pipeline_state:
      with pipeline_state.node_state_update_context(
          eg_node_uid) as node_state:
        node_state.update(pstate.NodeState.RUNNING)
      with pipeline_state.node_state_update_context(
          transform_node_uid) as node_state:
        node_state.update(pstate.NodeState.STARTING)
      with pipeline_state.node_state_update_context(
          trainer_node_uid) as node_state:
        node_state.update(pstate.NodeState.STARTED)
      with pipeline_state.node_state_update_context(
          evaluator_node_uid) as node_state:
        node_state.update(
            pstate.NodeState.FAILED,
            status_lib.Status(
                code=status_lib.Code.ABORTED, message='foobar error'))

    [view] = pstate.PipelineView.load_all(
        m, task_lib.PipelineUid.from_pipeline(pipeline))
    run_states_dict = view.get_node_run_states()
    self.assertEqual(
        run_state_pb2.RunState(state=run_state_pb2.RunState.RUNNING),
        run_states_dict['ExampleGen'])
    self.assertEqual(
        run_state_pb2.RunState(state=run_state_pb2.RunState.UNKNOWN),
        run_states_dict['Transform'])
    self.assertEqual(
        run_state_pb2.RunState(state=run_state_pb2.RunState.READY),
        run_states_dict['Trainer'])
    self.assertEqual(
        run_state_pb2.RunState(
            state=run_state_pb2.RunState.FAILED,
            status_code=run_state_pb2.RunState.StatusCodeValue(
                value=status_lib.Code.ABORTED),
            status_msg='foobar error'),
        run_states_dict['Evaluator'])
    self.assertEqual(
        run_state_pb2.RunState(state=run_state_pb2.RunState.READY),
        run_states_dict['Pusher'])
def test_stop_initiation(self):
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1')
    with pstate.PipelineState.new(m, pipeline) as pipeline_state:
      self.assertIsNone(pipeline_state.stop_initiated_reason())
      status = status_lib.Status(
          code=status_lib.Code.CANCELLED, message='foo bar')
      pipeline_state.initiate_stop(status)
      self.assertEqual(status, pipeline_state.stop_initiated_reason())

    # Reload from MLMD and verify.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      self.assertEqual(status, pipeline_state.stop_initiated_reason())
def test_node_state_update(self):
  node_state = pstate.NodeState()
  self.assertEqual(pstate.NodeState.STARTED, node_state.state)
  self.assertIsNone(node_state.status)

  status = status_lib.Status(code=status_lib.Code.CANCELLED, message='foobar')
  node_state.update(pstate.NodeState.STOPPING, status)
  self.assertEqual(pstate.NodeState.STOPPING, node_state.state)
  self.assertEqual(status, node_state.status)

  node_state.update(pstate.NodeState.STARTING)
  self.assertEqual(pstate.NodeState.STARTING, node_state.state)
  self.assertIsNone(node_state.status)
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
  """Publishes execution results to MLMD."""

  def _update_state(status: status_lib.Status) -> None:
    assert status.code != status_lib.Code.OK
    if status.code == status_lib.Code.CANCELLED:
      logging.info('Cancelling execution (id: %s); task id: %s; status: %s',
                   task.execution_id, task.task_id, status)
      execution_state = metadata_store_pb2.Execution.CANCELED
    else:
      logging.info(
          'Aborting execution (id: %s) due to error (code: %s); task id: %s',
          task.execution_id, status.code, task.task_id)
      execution_state = metadata_store_pb2.Execution.FAILED
    _update_execution_state_in_mlmd(mlmd_handle, task.execution_id,
                                    execution_state, status.message)
    pipeline_state.record_state_change_time()

  if result.status.code != status_lib.Code.OK:
    _update_state(result.status)
    return

  # TODO(b/182316162): Unify publisher handling so that post-execution
  # artifact logic is more cleanly handled.
  outputs_utils.tag_output_artifacts_with_version(task.output_artifacts)
  publish_params = dict(output_artifacts=task.output_artifacts)
  if result.output_artifacts is not None:
    # TODO(b/182316162): Unify publisher handling so that post-execution
    # artifact logic is more cleanly handled.
    outputs_utils.tag_output_artifacts_with_version(result.output_artifacts)
    publish_params['output_artifacts'] = result.output_artifacts
  elif result.executor_output is not None:
    if result.executor_output.execution_result.code != status_lib.Code.OK:
      _update_state(
          status_lib.Status(
              code=result.executor_output.execution_result.code,
              message=result.executor_output.execution_result.result_message))
      return
    # TODO(b/182316162): Unify publisher handling so that post-execution
    # artifact logic is more cleanly handled.
    outputs_utils.tag_executor_output_with_version(result.executor_output)
    publish_params['executor_output'] = result.executor_output
  execution_publish_utils.publish_succeeded_execution(mlmd_handle,
                                                      task.execution_id,
                                                      task.contexts,
                                                      **publish_params)
  pipeline_state.record_state_change_time()
def test_exceptions_are_surfaced(self, mock_publish):

  def _publish(**kwargs):
    task = kwargs['task']
    assert task_lib.is_exec_node_task(task)
    if task.node_uid.node_id == 'Transform':
      raise ValueError('test error')
    return mock.DEFAULT

  mock_publish.side_effect = _publish

  collector = _Collector()

  # Register a fake task scheduler.
  ts.TaskSchedulerRegistry.register(
      self._type_url,
      functools.partial(
          _FakeTaskScheduler, block_nodes={}, collector=collector))

  task_queue = tq.TaskQueue()

  with self._task_manager(task_queue) as task_manager:
    transform_task = _test_exec_node_task(
        'Transform', 'test-pipeline', pipeline=self._pipeline)
    trainer_task = _test_exec_node_task(
        'Trainer', 'test-pipeline', pipeline=self._pipeline)
    task_queue.enqueue(transform_task)
    task_queue.enqueue(trainer_task)

  self.assertTrue(task_manager.done())
  exception = task_manager.exception()
  self.assertIsNotNone(exception)
  self.assertIsInstance(exception, tm.TasksProcessingError)
  self.assertLen(exception.errors, 1)
  self.assertEqual('test error', str(exception.errors[0]))

  self.assertCountEqual([transform_task, trainer_task],
                        collector.scheduled_tasks)

  result_ok = ts.TaskSchedulerResult(
      status=status_lib.Status(
          code=status_lib.Code.OK, message='_FakeTaskScheduler result'))
  mock_publish.assert_has_calls([
      mock.call(mlmd_handle=mock.ANY, task=transform_task, result=result_ok),
      mock.call(mlmd_handle=mock.ANY, task=trainer_task, result=result_ok),
  ], any_order=True)
def test_successful_execution_resulting_in_output_artifacts(self):
  # Register a fake task scheduler that returns a successful execution result
  # and `OK` task scheduler status.
  self._register_task_scheduler(
      ts.TaskSchedulerResult(
          status=status_lib.Status(code=status_lib.Code.OK),
          output_artifacts=self._task.output_artifacts))
  task_manager = self._run_task_manager()
  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # Check that the task was processed and MLMD execution marked successful.
  self.assertTrue(self._task_queue.is_empty())
  execution = self._get_execution()
  self.assertEqual(metadata_store_pb2.Execution.COMPLETE,
                   execution.last_known_state)
def stop_node(mlmd_handle: metadata.Metadata,
              node_uid: task_lib.NodeUid,
              timeout_secs: Optional[float] = None) -> None:
  """Stops a node.

  Initiates a node stop operation and waits for the node execution to become
  inactive.

  Args:
    mlmd_handle: A handle to the MLMD db.
    node_uid: Uid of the node to be stopped.
    timeout_secs: Amount of time in seconds to wait for node to stop. If
      `None`, waits indefinitely.

  Raises:
    status_lib.StatusNotOkError: Failure to stop the node.
  """
  logging.info('Received request to stop node; node uid: %s', node_uid)
  with _PIPELINE_OPS_LOCK:
    with pstate.PipelineState.load(mlmd_handle,
                                   node_uid.pipeline_uid) as pipeline_state:
      nodes = pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
      filtered_nodes = [n for n in nodes if n.node_info.id == node_uid.node_id]
      if len(filtered_nodes) != 1:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.INTERNAL,
            message=(
                f'`stop_node` operation failed, unable to find node to stop: '
                f'{node_uid}'))
      node = filtered_nodes[0]
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        if node_state.is_stoppable():
          node_state.update(
              pstate.NodeState.STOPPING,
              status_lib.Status(
                  code=status_lib.Code.CANCELLED,
                  message='Cancellation requested by client.'))

  # Wait until the node is stopped or time out.
  _wait_for_node_inactivation(
      pipeline_state, node_uid, timeout_secs=timeout_secs)
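# A companion usage sketch for `stop_node` (illustrative only, not library
# code): the 'Trainer' node id, the connection config argument, and the
# 60-second timeout are assumptions made for the example.
def _example_stop_node_call(connection_config, pipeline) -> None:
  """Hypothetical caller that requests a single node to stop."""
  with metadata.Metadata(connection_config=connection_config) as mlmd_handle:
    node_uid = task_lib.NodeUid(
        pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
        node_id='Trainer')
    # Marks the node STOPPING in the pipeline state, then waits for its active
    # execution (if any) to become inactive or for the timeout to expire.
    stop_node(mlmd_handle, node_uid, timeout_secs=60.0)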
def test_initiate_node_start_stop(self):
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1')
    node_uid = task_lib.NodeUid(
        node_id='Trainer',
        pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline))
    with pstate.PipelineState.new(m, pipeline) as pipeline_state:
      pipeline_state.initiate_node_start(node_uid)
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))

    # Reload from MLMD and verify node is started.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))

      # Stop the node.
      status = status_lib.Status(
          code=status_lib.Code.ABORTED, message='foo bar')
      pipeline_state.initiate_node_stop(node_uid, status)
      self.assertEqual(status,
                       pipeline_state.node_stop_initiated_reason(node_uid))

    # Reload from MLMD and verify node is stopped.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      self.assertEqual(status,
                       pipeline_state.node_stop_initiated_reason(node_uid))

      # Restart node.
      pipeline_state.initiate_node_start(node_uid)
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))

    # Reload from MLMD and verify node is started.
    with pstate.PipelineState.load(
        m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
def test_scheduler_failure(self):
  # Register a fake task scheduler that returns a failure status.
  self._register_task_scheduler(
      ts.TaskSchedulerResult(
          status=status_lib.Status(
              code=status_lib.Code.ABORTED, message='foobar error'),
          executor_output=None))
  task_manager = self._run_task_manager()
  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # Check that the task was processed and MLMD execution marked failed.
  self.assertTrue(self._task_queue.is_empty())
  execution = self._get_execution()
  self.assertEqual(metadata_store_pb2.Execution.FAILED,
                   execution.last_known_state)
  self.assertEqual(
      'foobar error',
      data_types_utils.get_metadata_value(
          execution.custom_properties[constants.EXECUTION_ERROR_MSG_KEY]))
def test_task_generation_when_node_stopped(self, stop_transform):
  """Tests stopped nodes are ignored when generating tasks."""
  # Simulate that ExampleGen has already completed successfully.
  test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                  1)

  # Generate once.
  num_initial_executions = 1
  if stop_transform:
    num_tasks_generated = 1
    num_new_executions = 0
    num_active_executions = 0
    with self._mlmd_connection as m:
      pipeline_state = test_utils.get_or_create_pipeline_state(
          m, self._pipeline)
      with pipeline_state:
        with pipeline_state.node_state_update_context(
            task_lib.NodeUid.from_pipeline_node(
                self._pipeline, self._transform)) as node_state:
          node_state.update(pstate.NodeState.STOPPING,
                            status_lib.Status(code=status_lib.Code.CANCELLED))
  else:
    num_tasks_generated = 3
    num_new_executions = 1
    num_active_executions = 1
  tasks = self._generate_and_test(
      True,
      num_initial_executions=num_initial_executions,
      num_tasks_generated=num_tasks_generated,
      num_new_executions=num_new_executions,
      num_active_executions=num_active_executions)
  self.assertLen(tasks, num_tasks_generated)

  if stop_transform:
    self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
    self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
  else:
    self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
    self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
    self.assertTrue(task_lib.is_update_node_state_task(tasks[1]))
    self.assertEqual(pstate.NodeState.RUNNING, tasks[1].state)
    self.assertTrue(task_lib.is_exec_node_task(tasks[2]))
def node_stop_initiated_reason(
    self, node_uid: task_lib.NodeUid) -> Optional[status_lib.Status]:
  """Returns status object if node stop initiated, `None` otherwise."""
  if node_uid.pipeline_uid != self.pipeline_uid:
    raise RuntimeError(
        f'Node given by uid {node_uid} does not belong to pipeline given '
        f'by uid {self.pipeline_uid}')
  custom_properties = self.execution.custom_properties
  if _get_metadata_value(
      custom_properties.get(_node_stop_initiated_property(node_uid))) == 1:
    code = _get_metadata_value(
        custom_properties.get(_node_status_code_property(node_uid)))
    if code is None:
      code = status_lib.Code.UNKNOWN
    message = _get_metadata_value(
        custom_properties.get(_node_status_msg_property(node_uid)))
    return status_lib.Status(code=code, message=message)
  else:
    return None
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
  """Publishes execution results to MLMD."""

  def _update_state(status: status_lib.Status) -> None:
    assert status.code != status_lib.Code.OK
    if status.code == status_lib.Code.CANCELLED:
      logging.info('Cancelling execution (id: %s); task id: %s; status: %s',
                   task.execution.id, task.task_id, status)
      execution_state = metadata_store_pb2.Execution.CANCELED
    else:
      logging.info(
          'Aborting execution (id: %s) due to error (code: %s); task id: %s',
          task.execution.id, status.code, task.task_id)
      execution_state = metadata_store_pb2.Execution.FAILED
    _update_execution_state_in_mlmd(mlmd_handle, task.execution,
                                    execution_state, status.message)

  if result.status.code != status_lib.Code.OK:
    _update_state(result.status)
    return

  publish_params = dict(output_artifacts=task.output_artifacts)
  if result.output_artifacts is not None:
    publish_params['output_artifacts'] = result.output_artifacts
  elif result.executor_output is not None:
    if result.executor_output.execution_result.code != status_lib.Code.OK:
      _update_state(
          status_lib.Status(
              code=result.executor_output.execution_result.code,
              message=result.executor_output.execution_result.result_message))
      return
    publish_params['executor_output'] = result.executor_output
  execution_publish_utils.publish_succeeded_execution(mlmd_handle,
                                                      task.execution.id,
                                                      task.contexts,
                                                      **publish_params)
def test_handling_finalize_node_task(self, task_gen):
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1')
    pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
    pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
    pipeline_ops.initiate_pipeline_start(m, pipeline)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    finalize_reason = status_lib.Status(
        code=status_lib.Code.ABORTED, message='foo bar')
    task_gen.return_value.generate.side_effect = [
        [
            test_utils.create_exec_node_task(
                task_lib.NodeUid(
                    pipeline_uid=pipeline_uid, node_id='Transform')),
            task_lib.FinalizeNodeTask(
                node_uid=task_lib.NodeUid(
                    pipeline_uid=pipeline_uid, node_id='Trainer'),
                status=finalize_reason)
        ],
    ]

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())
    task_gen.return_value.generate.assert_called_once()
    task = task_queue.dequeue()
    task_queue.task_done(task)
    self.assertTrue(task_lib.is_exec_node_task(task))
    self.assertEqual(
        test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)

    # Load pipeline state and verify node stop initiation.
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      self.assertEqual(
          finalize_reason,
          pipeline_state.node_stop_initiated_reason(
              task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id='Trainer')))
def test_successful_execution_resulting_in_output_artifacts(self):
  # Register a fake task scheduler that returns a successful execution result
  # and `OK` task scheduler status.
  self._register_task_scheduler(
      ts.TaskSchedulerResult(
          status=status_lib.Status(code=status_lib.Code.OK),
          output=ts.ImporterNodeOutput(
              output_artifacts=self._task.output_artifacts)))
  task_manager = self._run_task_manager()
  self.assertTrue(task_manager.done())
  self.assertIsNone(task_manager.exception())

  # Check that the task was processed and MLMD execution marked successful.
  self.assertTrue(self._task_queue.is_empty())
  execution = self._get_execution()
  self.assertEqual(metadata_store_pb2.Execution.COMPLETE,
                   execution.last_known_state)

  # Check that stateful working dir and tmp_dir are removed.
  self.assertFalse(os.path.exists(self._task.stateful_working_dir))
  self.assertFalse(os.path.exists(self._task.tmp_dir))

  # Output artifact URI remains as execution was successful.
  self.assertTrue(os.path.exists(self._output_artifact_uri))
def _process_exec_node_task(self, scheduler: ts.TaskScheduler,
                            task: task_lib.ExecNodeTask) -> None:
  """Processes an `ExecNodeTask` using the given task scheduler."""
  # This is a blocking call to the scheduler which can take a long time to
  # complete for some types of task schedulers. The scheduler is expected to
  # handle any internal errors gracefully and return the result with an error
  # status. But in case the scheduler raises an exception, it is considered
  # a failed execution and MLMD is updated accordingly.
  try:
    result = scheduler.schedule()
  except Exception as e:  # pylint: disable=broad-except
    logging.exception('Exception raised by task scheduler; node uid: %s',
                      task.node_uid)
    result = ts.TaskSchedulerResult(
        status=status_lib.Status(
            code=status_lib.Code.ABORTED, message=str(e)))
  logging.info('For ExecNodeTask id: %s, task-scheduler result status: %s',
               task.task_id, result.status)
  _publish_execution_results(
      mlmd_handle=self._mlmd_handle, task=task, result=result)
  with self._tm_lock:
    del self._scheduler_by_node_uid[task.node_uid]
    self._task_queue.task_done(task)
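# A sketch of the scheduler contract that `_process_exec_node_task` relies on:
# `schedule()` blocks until the work is done and should return a
# `TaskSchedulerResult` rather than raise (a raised exception is converted to
# an ABORTED result above). The class below is illustrative only; it mirrors
# the fake schedulers used in the tests and is not part of the production
# code, and the exact `TaskScheduler` base-class signatures are assumed.
class _NoOpTaskScheduler(ts.TaskScheduler):
  """Hypothetical scheduler that completes immediately with an OK status."""

  def schedule(self) -> ts.TaskSchedulerResult:
    # Report success without producing an executor output payload.
    return ts.TaskSchedulerResult(
        status=status_lib.Status(code=status_lib.Code.OK),
        output=ts.ExecutorNodeOutput())

  def cancel(self) -> None:
    # Nothing to interrupt; a real scheduler would signal its blocked
    # `schedule()` call here (e.g. by setting a threading.Event).
    pass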
def test_handling_finalize_pipeline_task(self, task_gen):
  with self._mlmd_connection as m:
    pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
    pipeline_ops.initiate_pipeline_start(m, pipeline)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    finalize_reason = status_lib.Status(
        code=status_lib.Code.ABORTED, message='foo bar')
    task_gen.return_value.generate.side_effect = [
        [
            task_lib.FinalizePipelineTask(
                pipeline_uid=pipeline_uid, status=finalize_reason)
        ],
    ]

    task_queue = tq.TaskQueue()
    pipeline_ops.orchestrate(m, task_queue,
                             service_jobs.DummyServiceJobManager())
    task_gen.return_value.generate.assert_called_once()
    self.assertTrue(task_queue.is_empty())

    # Load pipeline state and verify stop initiation.
    with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
      self.assertEqual(finalize_reason,
                       pipeline_state.stop_initiated_reason())
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
  """Publishes execution results to MLMD."""

  def _update_state(status: status_lib.Status) -> None:
    assert status.code != status_lib.Code.OK
    _remove_output_dirs(task, result)
    _remove_task_dirs(task)
    if status.code == status_lib.Code.CANCELLED:
      logging.info('Cancelling execution (id: %s); task id: %s; status: %s',
                   task.execution_id, task.task_id, status)
      execution_state = metadata_store_pb2.Execution.CANCELED
    else:
      logging.info(
          'Aborting execution (id: %s) due to error (code: %s); task id: %s',
          task.execution_id, status.code, task.task_id)
      execution_state = metadata_store_pb2.Execution.FAILED
    _update_execution_state_in_mlmd(mlmd_handle, task.execution_id,
                                    execution_state, status.message)
    pipeline_state.record_state_change_time()

  if result.status.code != status_lib.Code.OK:
    _update_state(result.status)
    return

  # TODO(b/182316162): Unify publisher handling so that post-execution
  # artifact logic is more cleanly handled.
  outputs_utils.tag_output_artifacts_with_version(task.output_artifacts)
  if isinstance(result.output, ts.ExecutorNodeOutput):
    executor_output = result.output.executor_output
    if executor_output is not None:
      if executor_output.execution_result.code != status_lib.Code.OK:
        _update_state(
            status_lib.Status(
                code=executor_output.execution_result.code,
                message=executor_output.execution_result.result_message))
        return
      # TODO(b/182316162): Unify publisher handling so that post-execution
      # artifact logic is more cleanly handled.
      outputs_utils.tag_executor_output_with_version(executor_output)
    _remove_task_dirs(task)
    execution_publish_utils.publish_succeeded_execution(
        mlmd_handle,
        execution_id=task.execution_id,
        contexts=task.contexts,
        output_artifacts=task.output_artifacts,
        executor_output=executor_output)
  elif isinstance(result.output, ts.ImporterNodeOutput):
    output_artifacts = result.output.output_artifacts
    # TODO(b/182316162): Unify publisher handling so that post-execution
    # artifact logic is more cleanly handled.
    outputs_utils.tag_output_artifacts_with_version(output_artifacts)
    _remove_task_dirs(task)
    execution_publish_utils.publish_succeeded_execution(
        mlmd_handle,
        execution_id=task.execution_id,
        contexts=task.contexts,
        output_artifacts=output_artifacts)
  elif isinstance(result.output, ts.ResolverNodeOutput):
    resolved_input_artifacts = result.output.resolved_input_artifacts
    execution_publish_utils.publish_internal_execution(
        mlmd_handle,
        execution_id=task.execution_id,
        contexts=task.contexts,
        output_artifacts=resolved_input_artifacts)
  else:
    raise TypeError(f'Unable to process task scheduler result: {result}')

  pipeline_state.record_state_change_time()