예제 #1
0
def stop_pipeline(mlmd_handle: metadata.Metadata,
                  pipeline_uid: task_lib.PipelineUid,
                  timeout_secs: Optional[float] = None) -> None:
    """Requests cancellation of a pipeline and waits for it to go inactive.

    The stop request is recorded on the loaded pipeline state while
    `_PIPELINE_OPS_LOCK` is held; the wait for inactivation happens only after
    the lock has been released.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline_uid: Uid of the pipeline to be stopped.
      timeout_secs: Amount of time in seconds to wait for pipeline to stop. If
        `None`, waits indefinitely.

    Raises:
      status_lib.StatusNotOkError: Failure to initiate pipeline stop.
    """
    logging.info('Received request to stop pipeline; pipeline uid: %s',
                 pipeline_uid)
    with _PIPELINE_OPS_LOCK, pstate.PipelineState.load(
            mlmd_handle, pipeline_uid) as state:
        cancel_status = status_lib.Status(
            code=status_lib.Code.CANCELLED,
            message='Cancellation requested by client.')
        state.initiate_stop(cancel_status)
    logging.info('Waiting for pipeline to be stopped; pipeline uid: %s',
                 pipeline_uid)
    # The execution id is read from the state object after its context exited.
    execution_id = state.execution_id
    _wait_for_inactivation(mlmd_handle,
                           execution_id,
                           timeout_secs=timeout_secs)
    logging.info('Done waiting for pipeline to be stopped; pipeline uid: %s',
                 pipeline_uid)
예제 #2
0
    def test_stop_pipeline_non_existent_or_inactive(self, pipeline):
        """Stopping a missing or already completed pipeline yields NOT_FOUND."""
        with self._mlmd_connection as m:
            # No pipeline has been created yet, so the stop must be rejected.
            with self.assertRaises(status_lib.StatusNotOkError) as ctx:
                pipeline_ops.stop_pipeline(
                    m, task_lib.PipelineUid.from_pipeline(pipeline))
            self.assertEqual(status_lib.Code.NOT_FOUND, ctx.exception.code)

            # Start the pipeline, then drive its execution to COMPLETE.
            pipeline_ops.initiate_pipeline_start(m, pipeline)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            with pstate.PipelineState.load(m, pipeline_uid) as state:
                ok_status = status_lib.Status(code=status_lib.Code.OK)
                state.initiate_stop(ok_status)
                state.set_pipeline_execution_state(
                    metadata_store_pb2.Execution.COMPLETE)

            # A second stop request on the completed pipeline must also fail.
            with self.assertRaises(status_lib.StatusNotOkError) as ctx:
                pipeline_ops.stop_pipeline(m, pipeline_uid)
            self.assertEqual(status_lib.Code.NOT_FOUND, ctx.exception.code)
예제 #3
0
def stop_pipeline(
        mlmd_handle: metadata.Metadata,
        pipeline_uid: task_lib.PipelineUid,
        timeout_secs: float = DEFAULT_WAIT_FOR_INACTIVATION_TIMEOUT_SECS
) -> None:
    """Requests cancellation of a pipeline and waits for it to go inactive.

    A `CANCELLED` stop is initiated on the loaded pipeline state while
    `_PIPELINE_OPS_LOCK` is held; once the lock is released, the call blocks in
    `_wait_for_inactivation` on the pipeline's execution id.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline_uid: Uid of the pipeline to be stopped.
      timeout_secs: Amount of time in seconds to wait for pipeline to stop.

    Raises:
      status_lib.StatusNotOkError: Failure to initiate pipeline stop.
    """
    with _PIPELINE_OPS_LOCK:
        with pstate.PipelineState.load(mlmd_handle, pipeline_uid) as state:
            cancellation = status_lib.Status(
                code=status_lib.Code.CANCELLED,
                message='Cancellation requested by client.')
            state.initiate_stop(cancellation)
    # The execution id is read from the state object after its context exited.
    execution_id = state.execution_id
    _wait_for_inactivation(mlmd_handle,
                           execution_id,
                           timeout_secs=timeout_secs)
예제 #4
0
  def schedule(self) -> task_scheduler.TaskSchedulerResult:
    """Generates the importer node's output artifacts and reports success.

    Artifact properties come from the node's `IMPORT_RESULT_KEY` output spec;
    the source URI and reimport flag come from the task's exec properties.

    Returns:
      A `TaskSchedulerResult` with `OK` status and the generated artifacts.
    """
    exec_task = typing.cast(task_lib.ExecNodeTask, self.task)
    node = exec_task.get_pipeline_node()
    artifact_spec = node.outputs.outputs[
        importer.IMPORT_RESULT_KEY].artifact_spec

    # Convert the proto maps into plain dicts of Python values.
    properties = {
        name: data_types_utils.get_value(value)
        for name, value in artifact_spec.additional_properties.items()
    }
    custom_properties = {
        name: data_types_utils.get_value(value)
        for name, value in artifact_spec.additional_custom_properties.items()
    }

    exec_properties = exec_task.exec_properties
    output_artifacts = importer.generate_output_dict(
        metadata_handler=self.mlmd_handle,
        uri=str(exec_properties[importer.SOURCE_URI_KEY]),
        properties=properties,
        custom_properties=custom_properties,
        reimport=bool(exec_properties[importer.REIMPORT_OPTION_KEY]),
        output_artifact_class=types.Artifact(artifact_spec.type).type,
        mlmd_artifact_type=artifact_spec.type)

    ok_status = status_lib.Status(code=status_lib.Code.OK)
    return task_scheduler.TaskSchedulerResult(
        status=ok_status, output_artifacts=output_artifacts)
예제 #5
0
    def schedule(self) -> task_scheduler.TaskSchedulerResult:
        """Polls the node state stored on the execution until done or cancelled.

        Each iteration first waits on `self._cancel` for
        `_POLLING_INTERVAL_SECS`. A truthy wait result ends the loop and yields
        a `CANCELLED` result; otherwise the execution's node state is read and
        an `OK` result is returned once it is `COMPLETED`.

        Returns:
          A `TaskSchedulerResult` with `OK` or `CANCELLED` status.
        """
        while True:
            if self._cancel.wait(_POLLING_INTERVAL_SECS):
                break
            with mlmd_state.mlmd_execution_atomic_op(
                    mlmd_handle=self.mlmd_handle,
                    execution_id=self.task.execution_id) as execution:
                raw_state = execution.custom_properties.get(
                    NODE_STATE_PROPERTY_KEY)
                manual_state = ManualNodeState.from_mlmd_value(raw_state)
                if manual_state.state == ManualNodeState.COMPLETED:
                    return task_scheduler.TaskSchedulerResult(
                        status=status_lib.Status(code=status_lib.Code.OK),
                        output=task_scheduler.ExecutorNodeOutput())

        cancelled = status_lib.Status(code=status_lib.Code.CANCELLED)
        return task_scheduler.TaskSchedulerResult(
            status=cancelled, output=task_scheduler.ExecutorNodeOutput())
예제 #6
0
 def schedule(self):
     """Records the task, blocks if its node is listed, then returns `OK`.

     The task is logged and handed to the collector. When the task's node id
     is in `self._block_nodes`, the call waits on `self._stop_event` first.
     """
     logging.info('_FakeTaskScheduler: scheduling task: %s', self.task)
     self._collector.add_scheduled_task(self.task)
     must_block = self.task.node_uid.node_id in self._block_nodes
     if must_block:
         self._stop_event.wait()
     ok_status = status_lib.Status(code=status_lib.Code.OK)
     return ts.TaskSchedulerResult(
         status=ok_status,
         executor_output=execution_result_pb2.ExecutorOutput())
예제 #7
0
 def _abort_task(self, error_msg: str) -> task_lib.FinalizePipelineTask:
   """Builds the task that finalizes the pipeline with an `ABORTED` status.

   The given error is wrapped in a fixed prefix; the combined message is logged
   at error level and used as the status message.
   """
   abort_msg = (f'Aborting pipeline execution due to node execution failure; '
                f'error: {error_msg}')
   logging.error(abort_msg)
   aborted = status_lib.Status(code=status_lib.Code.ABORTED, message=abort_msg)
   return task_lib.FinalizePipelineTask(
       pipeline_uid=self._pipeline_uid, status=aborted)
예제 #8
0
 def schedule(self):
     """Records the task and returns `OK`, or `CANCELLED` for blocked nodes.

     When the task's node id is in `self._block_nodes`, the call waits on
     `self._cancel` and then reports `CANCELLED`; otherwise it reports `OK`.
     """
     logging.info('_FakeTaskScheduler: scheduling task: %s', self.task)
     self._collector.add_scheduled_task(self.task)
     code = status_lib.Code.OK
     if self.task.node_uid.node_id in self._block_nodes:
         self._cancel.wait()
         code = status_lib.Code.CANCELLED
     result_status = status_lib.Status(
         code=code, message='_FakeTaskScheduler result')
     return ts.TaskSchedulerResult(status=result_status)
예제 #9
0
def stop_node(
        mlmd_handle: metadata.Metadata,
        node_uid: task_lib.NodeUid,
        timeout_secs: float = DEFAULT_WAIT_FOR_INACTIVATION_TIMEOUT_SECS
) -> None:
    """Stops a node in an async pipeline.

    Under `_PIPELINE_OPS_LOCK`, the node is looked up in the loaded pipeline, a
    `CANCELLED` node stop is initiated, and the node's active executions are
    collected. With no active execution the call returns immediately;
    otherwise it waits, outside the lock, for the single active execution to
    become inactive.

    Args:
      mlmd_handle: A handle to the MLMD db.
      node_uid: Uid of the node to be stopped.
      timeout_secs: Amount of time in seconds to wait for node to stop.

    Raises:
      status_lib.StatusNotOkError: Failure to stop the node.
    """
    with _PIPELINE_OPS_LOCK:
        with pstate.PipelineState.load(
                mlmd_handle, node_uid.pipeline_uid) as state:
            all_nodes = pstate.get_all_pipeline_nodes(state.pipeline)
            matches = [
                candidate for candidate in all_nodes
                if candidate.node_info.id == node_uid.node_id
            ]
            if len(matches) != 1:
                raise status_lib.StatusNotOkError(
                    code=status_lib.Code.INTERNAL,
                    message=
                    (f'`stop_node` operation failed, unable to find node to stop: '
                     f'{node_uid}'))
            node = matches[0]
            cancellation = status_lib.Status(
                code=status_lib.Code.CANCELLED,
                message='Cancellation requested by client.')
            state.initiate_node_stop(node_uid, cancellation)

        node_executions = task_gen_utils.get_executions(mlmd_handle, node)
        still_active = list(
            filter(execution_lib.is_execution_active, node_executions))
        if not still_active:
            # Nothing is running for this node, so there is nothing to wait on.
            return
        if len(still_active) > 1:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.INTERNAL,
                message=
                (f'Unexpected multiple active executions for node: {node_uid}'
                 ))
    _wait_for_inactivation(mlmd_handle,
                           still_active[0],
                           timeout_secs=timeout_secs)
예제 #10
0
 def _abort_node_task(
         self, node_uid: task_lib.NodeUid) -> task_lib.FinalizeNodeTask:
     """Builds the task that finalizes a node with an `ABORTED` status.

     An error is logged for the given node uid before the task is created.
     """
     logging.error(
         'Required service node not running or healthy, node uid: %s',
         node_uid)
     abort_msg = (f'Aborting node execution as the associated service '
                  f'job is not running or healthy; problematic node '
                  f'uid: {node_uid}')
     aborted = status_lib.Status(code=status_lib.Code.ABORTED,
                                 message=abort_msg)
     return task_lib.FinalizeNodeTask(node_uid=node_uid, status=aborted)
예제 #11
0
 def stop_initiated_reason(self) -> Optional[status_lib.Status]:
     """Returns the stored stop status, or `None` if no stop was initiated.

     A missing status code in the custom properties is reported as `UNKNOWN`.
     """
     props = self.execution.custom_properties
     stop_flag = _get_metadata_value(props.get(_STOP_INITIATED))
     if stop_flag != 1:
         return None
     code = _get_metadata_value(props.get(_PIPELINE_STATUS_CODE))
     if code is None:
         code = status_lib.Code.UNKNOWN
     message = _get_metadata_value(props.get(_PIPELINE_STATUS_MSG))
     return status_lib.Status(code=code, message=message)
예제 #12
0
 def schedule(self) -> ts.TaskSchedulerResult:
     """Returns an `OK` result echoing the task's output artifacts.

     Each artifact in `self.task.output_artifacts` is copied into a fresh
     `ExecutorOutput` under the same key, and the result is logged.
     """
     logging.info('Processing ExecNodeTask: %s', self.task)
     output_proto = execution_result_pb2.ExecutorOutput()
     output_proto.execution_result.code = status_lib.Code.OK
     for key, artifact_list in self.task.output_artifacts.items():
         for item in artifact_list:
             # Index the map per artifact so empty lists add no map entry.
             slot = output_proto.output_artifacts[key].artifacts.add()
             slot.CopyFrom(item.mlmd_artifact)
     ok_status = status_lib.Status(code=status_lib.Code.OK)
     node_output = ts.ExecutorNodeOutput(executor_output=output_proto)
     result = ts.TaskSchedulerResult(status=ok_status, output=node_output)
     logging.info('Result: %s', result)
     return result
예제 #13
0
    def test_initiate_node_start_stop(self):
        """Node state updates are visible immediately and after a reload."""
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1')
            node_uid = task_lib.NodeUid(
                node_id='Trainer',
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline))
            with pstate.PipelineState.new(m, pipeline) as state:
                with state.node_state_update_context(node_uid) as node_state:
                    node_state.update(pstate.NodeState.STARTING)
                self.assertEqual(pstate.NodeState.STARTING,
                                 state.get_node_state(node_uid).state)

            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)

            # Reload: the node is still STARTING; then move it to STOPPING.
            with pstate.PipelineState.load(m, pipeline_uid) as state:
                self.assertEqual(pstate.NodeState.STARTING,
                                 state.get_node_state(node_uid).state)

                status = status_lib.Status(code=status_lib.Code.ABORTED,
                                           message='foo bar')
                with state.node_state_update_context(node_uid) as node_state:
                    node_state.update(pstate.NodeState.STOPPING, status)
                stopping = state.get_node_state(node_uid)
                self.assertEqual(pstate.NodeState.STOPPING, stopping.state)
                self.assertEqual(status, stopping.status)

            # Reload: the node is still STOPPING; then move it to STARTED.
            with pstate.PipelineState.load(m, pipeline_uid) as state:
                reloaded = state.get_node_state(node_uid)
                self.assertEqual(pstate.NodeState.STOPPING, reloaded.state)
                self.assertEqual(status, reloaded.status)

                with state.node_state_update_context(node_uid) as node_state:
                    node_state.update(pstate.NodeState.STARTED)
                self.assertEqual(pstate.NodeState.STARTED,
                                 state.get_node_state(node_uid).state)

            # Reload once more and verify the node is STARTED.
            with pstate.PipelineState.load(m, pipeline_uid) as state:
                self.assertEqual(pstate.NodeState.STARTED,
                                 state.get_node_state(node_uid).state)
예제 #14
0
    def test_pipeline_view_get_node_run_states(self):
        """Checks the run states a pipeline view reports for each node."""
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            for node_id in ('ExampleGen', 'Transform', 'Trainer', 'Evaluator',
                            'Pusher'):
                pipeline.nodes.add().pipeline_node.node_info.id = node_id
            eg_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
            transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform')
            trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')
            evaluator_node_uid = task_lib.NodeUid(pipeline_uid, 'Evaluator')

            # Updates that carry no status; Pusher is left untouched.
            plain_updates = (
                (eg_node_uid, pstate.NodeState.RUNNING),
                (transform_node_uid, pstate.NodeState.STARTING),
                (trainer_node_uid, pstate.NodeState.STARTED),
            )
            with pstate.PipelineState.new(m, pipeline) as pipeline_state:
                for uid, new_state in plain_updates:
                    with pipeline_state.node_state_update_context(
                            uid) as node_state:
                        node_state.update(new_state)
                with pipeline_state.node_state_update_context(
                        evaluator_node_uid) as node_state:
                    failure = status_lib.Status(code=status_lib.Code.ABORTED,
                                                message='foobar error')
                    node_state.update(pstate.NodeState.FAILED, failure)

            [view] = pstate.PipelineView.load_all(
                m, task_lib.PipelineUid.from_pipeline(pipeline))
            run_states_dict = view.get_node_run_states()

            run_state_cls = run_state_pb2.RunState
            expectations = (
                ('ExampleGen', run_state_cls(state=run_state_cls.RUNNING)),
                ('Transform', run_state_cls(state=run_state_cls.UNKNOWN)),
                ('Trainer', run_state_cls(state=run_state_cls.READY)),
                ('Evaluator',
                 run_state_cls(
                     state=run_state_cls.FAILED,
                     status_code=run_state_cls.StatusCodeValue(
                         value=status_lib.Code.ABORTED),
                     status_msg='foobar error')),
                ('Pusher', run_state_cls(state=run_state_cls.READY)),
            )
            for node_id, expected in expectations:
                self.assertEqual(expected, run_states_dict[node_id])
예제 #15
0
  def test_stop_initiation(self):
    """A stop reason is readable right after initiation and after a reload."""
    with self._mlmd_connection as m:
      pipeline = _test_pipeline('pipeline1')
      with pstate.PipelineState.new(m, pipeline) as new_state:
        # No stop has been requested on a freshly created pipeline.
        self.assertIsNone(new_state.stop_initiated_reason())
        status = status_lib.Status(
            code=status_lib.Code.CANCELLED, message='foo bar')
        new_state.initiate_stop(status)
        self.assertEqual(status, new_state.stop_initiated_reason())

      # The same reason must be returned by state reloaded from MLMD.
      pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
      with pstate.PipelineState.load(m, pipeline_uid) as reloaded_state:
        self.assertEqual(status, reloaded_state.stop_initiated_reason())
예제 #16
0
    def test_node_state_update(self):
        """`NodeState.update` sets the state and replaces the status."""
        state_under_test = pstate.NodeState()
        # A default-constructed node state is STARTED with no status.
        self.assertEqual(pstate.NodeState.STARTED, state_under_test.state)
        self.assertIsNone(state_under_test.status)

        cancelled = status_lib.Status(code=status_lib.Code.CANCELLED,
                                      message='foobar')
        state_under_test.update(pstate.NodeState.STOPPING, cancelled)
        self.assertEqual(pstate.NodeState.STOPPING, state_under_test.state)
        self.assertEqual(cancelled, state_under_test.status)

        # Updating without a status leaves no status behind.
        state_under_test.update(pstate.NodeState.STARTING)
        self.assertEqual(pstate.NodeState.STARTING, state_under_test.state)
        self.assertIsNone(state_under_test.status)
예제 #17
0
파일: task_manager.py 프로젝트: htahir1/tfx
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
  """Records the outcome of a scheduled task in MLMD.

  A non-OK scheduler status, or a non-OK executor result code, marks the
  execution CANCELED or FAILED. Otherwise the output artifacts are tagged with
  a version and the execution is published as succeeded. Every path ends with
  `pipeline_state.record_state_change_time()`.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task: The task whose execution is being published.
    result: The scheduler result for `task`.
  """

  def _record_failure(failure: status_lib.Status) -> None:
    """Marks the execution CANCELED or FAILED according to `failure.code`."""
    assert failure.code != status_lib.Code.OK
    if failure.code == status_lib.Code.CANCELLED:
      logging.info('Cancelling execution (id: %s); task id: %s; status: %s',
                   task.execution_id, task.task_id, failure)
      new_state = metadata_store_pb2.Execution.CANCELED
    else:
      logging.info(
          'Aborting execution (id: %s) due to error (code: %s); task id: %s',
          task.execution_id, failure.code, task.task_id)
      new_state = metadata_store_pb2.Execution.FAILED
    _update_execution_state_in_mlmd(mlmd_handle, task.execution_id, new_state,
                                    failure.message)
    pipeline_state.record_state_change_time()

  scheduler_status = result.status
  if scheduler_status.code != status_lib.Code.OK:
    _record_failure(scheduler_status)
    return

  # TODO(b/182316162): Unify publisher handing so that post-execution artifact
  # logic is more cleanly handled.
  outputs_utils.tag_output_artifacts_with_version(task.output_artifacts)
  publish_params = {'output_artifacts': task.output_artifacts}
  if result.output_artifacts is not None:
    # TODO(b/182316162): Unify publisher handing so that post-execution artifact
    # logic is more cleanly handled.
    outputs_utils.tag_output_artifacts_with_version(result.output_artifacts)
    publish_params['output_artifacts'] = result.output_artifacts
  elif result.executor_output is not None:
    execution_result = result.executor_output.execution_result
    if execution_result.code != status_lib.Code.OK:
      _record_failure(
          status_lib.Status(
              code=execution_result.code,
              message=execution_result.result_message))
      return
    # TODO(b/182316162): Unify publisher handing so that post-execution artifact
    # logic is more cleanly handled.
    outputs_utils.tag_executor_output_with_version(result.executor_output)
    publish_params['executor_output'] = result.executor_output

  execution_publish_utils.publish_succeeded_execution(
      mlmd_handle, task.execution_id, task.contexts, **publish_params)
  pipeline_state.record_state_change_time()
예제 #18
0
    def test_exceptions_are_surfaced(self, mock_publish):
        """An error raised while publishing surfaces via the task manager."""

        def _fail_for_transform(**kwargs):
            published_task = kwargs['task']
            assert task_lib.is_exec_node_task(published_task)
            if published_task.node_uid.node_id == 'Transform':
                raise ValueError('test error')
            return mock.DEFAULT

        mock_publish.side_effect = _fail_for_transform

        collector = _Collector()

        # Register a fake task scheduler that blocks on no nodes.
        scheduler_factory = functools.partial(_FakeTaskScheduler,
                                              block_nodes={},
                                              collector=collector)
        ts.TaskSchedulerRegistry.register(self._type_url, scheduler_factory)

        task_queue = tq.TaskQueue()

        with self._task_manager(task_queue) as task_manager:
            transform_task = _test_exec_node_task('Transform',
                                                  'test-pipeline',
                                                  pipeline=self._pipeline)
            trainer_task = _test_exec_node_task('Trainer',
                                                'test-pipeline',
                                                pipeline=self._pipeline)
            for queued in (transform_task, trainer_task):
                task_queue.enqueue(queued)

        self.assertTrue(task_manager.done())
        surfaced = task_manager.exception()
        self.assertIsNotNone(surfaced)
        self.assertIsInstance(surfaced, tm.TasksProcessingError)
        self.assertLen(surfaced.errors, 1)
        self.assertEqual('test error', str(surfaced.errors[0]))

        # Both tasks were scheduled and published, in either order.
        self.assertCountEqual([transform_task, trainer_task],
                              collector.scheduled_tasks)
        result_ok = ts.TaskSchedulerResult(status=status_lib.Status(
            code=status_lib.Code.OK, message='_FakeTaskScheduler result'))
        expected_calls = [
            mock.call(mlmd_handle=mock.ANY, task=expected_task,
                      result=result_ok)
            for expected_task in (transform_task, trainer_task)
        ]
        mock_publish.assert_has_calls(expected_calls, any_order=True)
예제 #19
0
    def test_successful_execution_resulting_in_output_artifacts(self):
        """An `OK` result with output artifacts marks the execution COMPLETE."""
        # The fake scheduler returns an `OK` status with the task's artifacts.
        ok_status = status_lib.Status(code=status_lib.Code.OK)
        scheduler_result = ts.TaskSchedulerResult(
            status=ok_status, output_artifacts=self._task.output_artifacts)
        self._register_task_scheduler(scheduler_result)

        task_manager = self._run_task_manager()
        self.assertTrue(task_manager.done())
        self.assertIsNone(task_manager.exception())

        # The queue is drained and the MLMD execution is marked successful.
        self.assertTrue(self._task_queue.is_empty())
        last_known_state = self._get_execution().last_known_state
        self.assertEqual(metadata_store_pb2.Execution.COMPLETE,
                         last_known_state)
예제 #20
0
def stop_node(mlmd_handle: metadata.Metadata,
              node_uid: task_lib.NodeUid,
              timeout_secs: Optional[float] = None) -> None:
    """Stops a node.

    Initiates a node stop operation and waits for the node execution to become
    inactive. The node state is moved to `STOPPING` with a `CANCELLED` status
    only when `node_state.is_stoppable()` is true; the wait for inactivation
    runs in either case, after `_PIPELINE_OPS_LOCK` has been released.

    Args:
      mlmd_handle: A handle to the MLMD db.
      node_uid: Uid of the node to be stopped.
      timeout_secs: Amount of time in seconds to wait for node to stop. If
        `None`, waits indefinitely.

    Raises:
      status_lib.StatusNotOkError: Failure to stop the node.
    """
    logging.info('Received request to stop node; node uid: %s', node_uid)
    with _PIPELINE_OPS_LOCK:
        with pstate.PipelineState.load(
                mlmd_handle, node_uid.pipeline_uid) as pipeline_state:
            nodes = pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
            filtered_nodes = [
                n for n in nodes if n.node_info.id == node_uid.node_id
            ]
            # The lookup only validates that exactly one node matches; the
            # node object itself is not needed below.
            if len(filtered_nodes) != 1:
                raise status_lib.StatusNotOkError(
                    code=status_lib.Code.INTERNAL,
                    message=
                    (f'`stop_node` operation failed, unable to find node to stop: '
                     f'{node_uid}'))
            with pipeline_state.node_state_update_context(
                    node_uid) as node_state:
                if node_state.is_stoppable():
                    node_state.update(
                        pstate.NodeState.STOPPING,
                        status_lib.Status(
                            code=status_lib.Code.CANCELLED,
                            message='Cancellation requested by client.'))

    # Wait until the node is stopped or time out.
    _wait_for_node_inactivation(pipeline_state,
                                node_uid,
                                timeout_secs=timeout_secs)
예제 #21
0
    def test_initiate_node_start_stop(self):
        """Node stop reasons are set, persisted, and cleared on restart."""
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1')
            node_uid = task_lib.NodeUid(
                node_id='Trainer',
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline))
            with pstate.PipelineState.new(m, pipeline) as state:
                state.initiate_node_start(node_uid)
                self.assertIsNone(state.node_stop_initiated_reason(node_uid))

            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)

            # Reload: no stop reason yet; then initiate a node stop.
            with pstate.PipelineState.load(m, pipeline_uid) as state:
                self.assertIsNone(state.node_stop_initiated_reason(node_uid))

                status = status_lib.Status(code=status_lib.Code.ABORTED,
                                           message='foo bar')
                state.initiate_node_stop(node_uid, status)
                self.assertEqual(status,
                                 state.node_stop_initiated_reason(node_uid))

            # Reload: the stop reason persisted; then restart the node.
            with pstate.PipelineState.load(m, pipeline_uid) as state:
                self.assertEqual(status,
                                 state.node_stop_initiated_reason(node_uid))

                state.initiate_node_start(node_uid)
                self.assertIsNone(state.node_stop_initiated_reason(node_uid))

            # Reload once more: the stop reason stays cleared.
            with pstate.PipelineState.load(m, pipeline_uid) as state:
                self.assertIsNone(state.node_stop_initiated_reason(node_uid))
예제 #22
0
    def test_scheduler_failure(self):
        """A failed scheduler status marks the execution FAILED with its message."""
        # The fake scheduler returns an `ABORTED` status and no executor output.
        failure = status_lib.Status(code=status_lib.Code.ABORTED,
                                    message='foobar error')
        self._register_task_scheduler(
            ts.TaskSchedulerResult(status=failure, executor_output=None))

        task_manager = self._run_task_manager()
        self.assertTrue(task_manager.done())
        self.assertIsNone(task_manager.exception())

        # The queue is drained and the MLMD execution is marked failed.
        self.assertTrue(self._task_queue.is_empty())
        execution = self._get_execution()
        self.assertEqual(metadata_store_pb2.Execution.FAILED,
                         execution.last_known_state)
        recorded_msg = data_types_utils.get_metadata_value(
            execution.custom_properties[constants.EXECUTION_ERROR_MSG_KEY])
        self.assertEqual('foobar error', recorded_msg)
예제 #23
0
  def test_task_generation_when_node_stopped(self, stop_transform):
    """Tests stopped nodes are ignored when generating tasks."""
    # Simulate that ExampleGen has already completed successfully.
    test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                    1)

    # Expected counts for the single generation pass below.
    if stop_transform:
      expected_tasks, expected_new, expected_active = 1, 0, 0
      # Put the Transform node into STOPPING before generating tasks.
      with self._mlmd_connection as m:
        pipeline_state = test_utils.get_or_create_pipeline_state(
            m, self._pipeline)
        with pipeline_state:
          transform_uid = task_lib.NodeUid.from_pipeline_node(
              self._pipeline, self._transform)
          with pipeline_state.node_state_update_context(
              transform_uid) as node_state:
            node_state.update(pstate.NodeState.STOPPING,
                              status_lib.Status(code=status_lib.Code.CANCELLED))
    else:
      expected_tasks, expected_new, expected_active = 3, 1, 1

    tasks = self._generate_and_test(
        True,
        num_initial_executions=1,
        num_tasks_generated=expected_tasks,
        num_new_executions=expected_new,
        num_active_executions=expected_active)
    self.assertLen(tasks, expected_tasks)

    # In both cases the first task is a node state update to RUNNING.
    self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
    self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
    if not stop_transform:
      self.assertTrue(task_lib.is_update_node_state_task(tasks[1]))
      self.assertEqual(pstate.NodeState.RUNNING, tasks[1].state)
      self.assertTrue(task_lib.is_exec_node_task(tasks[2]))
예제 #24
0
 def node_stop_initiated_reason(
         self, node_uid: task_lib.NodeUid) -> Optional[status_lib.Status]:
     """Returns the node's stored stop status, or `None` if none was initiated.

     A missing status code in the custom properties is reported as `UNKNOWN`.

     Raises:
       RuntimeError: If `node_uid` belongs to a different pipeline.
     """
     if node_uid.pipeline_uid != self.pipeline_uid:
         raise RuntimeError(
             f'Node given by uid {node_uid} does not belong to pipeline given '
             f'by uid {self.pipeline_uid}')
     props = self.execution.custom_properties
     stop_flag = _get_metadata_value(
         props.get(_node_stop_initiated_property(node_uid)))
     if stop_flag != 1:
         return None
     code = _get_metadata_value(
         props.get(_node_status_code_property(node_uid)))
     if code is None:
         code = status_lib.Code.UNKNOWN
     message = _get_metadata_value(
         props.get(_node_status_msg_property(node_uid)))
     return status_lib.Status(code=code, message=message)
예제 #25
0
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
  """Records the outcome of a scheduled task's execution in MLMD.

  A non-OK scheduler status, or a non-OK execution result inside the executor
  output, updates the execution state to CANCELED (for a CANCELLED code) or
  FAILED (any other code). Otherwise the execution is published as succeeded.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task: The `ExecNodeTask` whose execution results are being published.
    result: The task scheduler result produced for `task`.
  """

  def _record_failure(failure: status_lib.Status) -> None:
    """Writes a CANCELED or FAILED execution state for a non-OK status."""
    assert failure.code != status_lib.Code.OK
    was_cancelled = failure.code == status_lib.Code.CANCELLED
    if not was_cancelled:
      logging.info(
          'Aborting execution (id: %s) due to error (code: %s); task id: %s',
          task.execution.id, failure.code, task.task_id)
      new_state = metadata_store_pb2.Execution.FAILED
    else:
      logging.info('Cancelling execution (id: %s); task id: %s; status: %s',
                   task.execution.id, task.task_id, failure)
      new_state = metadata_store_pb2.Execution.CANCELED
    _update_execution_state_in_mlmd(mlmd_handle, task.execution, new_state,
                                    failure.message)

  # The scheduler itself reported a problem: nothing to publish as success.
  if result.status.code != status_lib.Code.OK:
    _record_failure(result.status)
    return

  # Start from the task's own output artifacts; the scheduler result may
  # override them or contribute an executor output instead.
  publish_kwargs = {'output_artifacts': task.output_artifacts}
  if result.output_artifacts is not None:
    publish_kwargs['output_artifacts'] = result.output_artifacts
  elif result.executor_output is not None:
    execution_result = result.executor_output.execution_result
    if execution_result.code != status_lib.Code.OK:
      _record_failure(
          status_lib.Status(
              code=execution_result.code,
              message=execution_result.result_message))
      return
    publish_kwargs['executor_output'] = result.executor_output

  execution_publish_utils.publish_succeeded_execution(
      mlmd_handle, task.execution.id, task.contexts, **publish_kwargs)
예제 #26
0
    def test_handling_finalize_node_task(self, task_gen):
        """Orchestrating a `FinalizeNodeTask` records the node's stop reason.

        The task generator mock yields an exec task for 'Transform' and a
        finalize task for 'Trainer'. Only the exec task is expected on the
        queue, and the finalize status must be readable back from the
        pipeline state as the 'Trainer' stop reason.

        Args:
          task_gen: Mocked task generator class (presumably injected by a patch
            decorator that is outside this view).
        """
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1')
            for node_id in ('Transform', 'Trainer'):
                pipeline.nodes.add().pipeline_node.node_info.id = node_id
            pipeline_ops.initiate_pipeline_start(m, pipeline)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)

            def _uid_of(node_id):
                return task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                        node_id=node_id)

            stop_status = status_lib.Status(code=status_lib.Code.ABORTED,
                                            message='foo bar')
            transform_exec_task = test_utils.create_exec_node_task(
                _uid_of('Transform'))
            trainer_finalize_task = task_lib.FinalizeNodeTask(
                node_uid=_uid_of('Trainer'), status=stop_status)
            # A single generate() call returning both tasks.
            task_gen.return_value.generate.side_effect = [
                [transform_exec_task, trainer_finalize_task],
            ]

            queue = tq.TaskQueue()
            job_manager = service_jobs.DummyServiceJobManager()
            pipeline_ops.orchestrate(m, queue, job_manager)
            task_gen.return_value.generate.assert_called_once()

            # The exec task for Transform is the one pulled off the queue.
            dequeued = queue.dequeue()
            queue.task_done(dequeued)
            self.assertTrue(task_lib.is_exec_node_task(dequeued))
            expected_uid = test_utils.create_node_uid('pipeline1', 'Transform')
            self.assertEqual(expected_uid, dequeued.node_uid)

            # Load pipeline state and verify node stop initiation.
            with pstate.PipelineState.load(m, pipeline_uid) as state:
                recorded_reason = state.node_stop_initiated_reason(
                    _uid_of('Trainer'))
                self.assertEqual(stop_status, recorded_reason)
예제 #27
0
    def test_successful_execution_resulting_in_output_artifacts(self):
        """An OK importer result completes the execution and cleans task dirs.

        Registers a fake task scheduler returning an `OK` status together with
        an `ImporterNodeOutput`, runs the task manager, and then checks the MLMD
        execution state and which directories remain on disk.
        """
        ok_status = status_lib.Status(code=status_lib.Code.OK)
        importer_output = ts.ImporterNodeOutput(
            output_artifacts=self._task.output_artifacts)
        self._register_task_scheduler(
            ts.TaskSchedulerResult(status=ok_status, output=importer_output))

        manager = self._run_task_manager()
        self.assertTrue(manager.done())
        self.assertIsNone(manager.exception())

        # The task must have been consumed and the execution marked COMPLETE.
        self.assertTrue(self._task_queue.is_empty())
        final_state = self._get_execution().last_known_state
        self.assertEqual(metadata_store_pb2.Execution.COMPLETE, final_state)

        # Stateful working dir and tmp dir are expected to be gone ...
        for removed_dir in (self._task.stateful_working_dir,
                            self._task.tmp_dir):
            self.assertFalse(os.path.exists(removed_dir))
        # ... while the output artifact URI survives a successful execution.
        self.assertTrue(os.path.exists(self._output_artifact_uri))
예제 #28
0
파일: task_manager.py 프로젝트: htahir1/tfx
 def _process_exec_node_task(self, scheduler: ts.TaskScheduler,
                             task: task_lib.ExecNodeTask) -> None:
   """Runs the scheduler for an `ExecNodeTask` and publishes its outcome.

   `scheduler.schedule()` is a blocking call that may take a long time for
   some scheduler types. Schedulers are expected to report internal errors
   through the returned status; if one raises instead, the exception is
   logged and converted into an `ABORTED` result so that it is published as
   a failed execution.

   After publishing, the scheduler is dropped from
   `self._scheduler_by_node_uid` and the task is marked done on the task
   queue, both while holding `self._tm_lock`.

   NOTE(review): if `_publish_execution_results` raises, the cleanup under the
   lock is skipped (no `finally`) -- confirm this is intended.

   Args:
     scheduler: Task scheduler to run for `task`.
     task: The `ExecNodeTask` being processed.
   """
   try:
     outcome = scheduler.schedule()
   except Exception as exc:  # pylint: disable=broad-except
     logging.exception('Exception raised by task scheduler; node uid: %s',
                       task.node_uid)
     aborted = status_lib.Status(
         code=status_lib.Code.ABORTED, message=str(exc))
     outcome = ts.TaskSchedulerResult(status=aborted)

   logging.info('For ExecNodeTask id: %s, task-scheduler result status: %s',
                task.task_id, outcome.status)
   _publish_execution_results(
       mlmd_handle=self._mlmd_handle, task=task, result=outcome)

   # Bookkeeping is done under the task manager lock.
   with self._tm_lock:
     del self._scheduler_by_node_uid[task.node_uid]
     self._task_queue.task_done(task)
예제 #29
0
    def test_handling_finalize_pipeline_task(self, task_gen):
        """Orchestrating a `FinalizePipelineTask` records the stop reason.

        The task generator mock yields a single finalize task for a SYNC
        pipeline. Nothing should be enqueued, and the finalize status must be
        readable back from the pipeline state as the stop-initiated reason.

        Args:
          task_gen: Mocked task generator class (presumably injected by a patch
            decorator that is outside this view).
        """
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
            pipeline_ops.initiate_pipeline_start(m, pipeline)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)

            stop_status = status_lib.Status(code=status_lib.Code.ABORTED,
                                            message='foo bar')
            finalize_task = task_lib.FinalizePipelineTask(
                pipeline_uid=pipeline_uid, status=stop_status)
            # A single generate() call returning only the finalize task.
            task_gen.return_value.generate.side_effect = [[finalize_task]]

            queue = tq.TaskQueue()
            job_manager = service_jobs.DummyServiceJobManager()
            pipeline_ops.orchestrate(m, queue, job_manager)
            task_gen.return_value.generate.assert_called_once()
            self.assertTrue(queue.is_empty())

            # Load pipeline state and verify stop initiation.
            with pstate.PipelineState.load(m, pipeline_uid) as state:
                recorded_reason = state.stop_initiated_reason()
                self.assertEqual(stop_status, recorded_reason)
예제 #30
0
def _publish_execution_results(mlmd_handle: metadata.Metadata,
                               task: task_lib.ExecNodeTask,
                               result: ts.TaskSchedulerResult) -> None:
    """Publishes execution results to MLMD.

    A non-OK status (either the task scheduler status or the execution result
    carried by an executor output) is handled by `_update_state`, which calls
    `_remove_output_dirs` and `_remove_task_dirs` and then updates the
    execution state to CANCELED or FAILED. An OK result is published according
    to the type of `result.output`.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task: The `ExecNodeTask` whose execution results are being published.
      result: The task scheduler result produced for `task`.

    Raises:
      TypeError: If `result.output` is not an `ExecutorNodeOutput`,
        `ImporterNodeOutput` or `ResolverNodeOutput`.
    """
    def _update_state(status: status_lib.Status) -> None:
        """Cleans up directories and records a non-OK `status` in MLMD."""
        assert status.code != status_lib.Code.OK
        # Directory cleanup happens before the MLMD state update.
        _remove_output_dirs(task, result)
        _remove_task_dirs(task)
        if status.code == status_lib.Code.CANCELLED:
            logging.info(
                'Cancelling execution (id: %s); task id: %s; status: %s',
                task.execution_id, task.task_id, status)
            execution_state = metadata_store_pb2.Execution.CANCELED
        else:
            # Every non-OK code other than CANCELLED is recorded as FAILED.
            logging.info(
                'Aborting execution (id: %s) due to error (code: %s); task id: %s',
                task.execution_id, status.code, task.task_id)
            execution_state = metadata_store_pb2.Execution.FAILED
        _update_execution_state_in_mlmd(mlmd_handle, task.execution_id,
                                        execution_state, status.message)
        # Callers return right after `_update_state`, so the state change time
        # is recorded here for the failure/cancellation paths.
        pipeline_state.record_state_change_time()

    # The scheduler itself reported a non-OK status: nothing to publish.
    if result.status.code != status_lib.Code.OK:
        _update_state(result.status)
        return

    # TODO(b/182316162): Unify publisher handing so that post-execution artifact
    # logic is more cleanly handled.
    outputs_utils.tag_output_artifacts_with_version(task.output_artifacts)
    if isinstance(result.output, ts.ExecutorNodeOutput):
        executor_output = result.output.executor_output
        if executor_output is not None:
            # The scheduler status was OK but the executor output may still
            # carry a non-OK execution result.
            if executor_output.execution_result.code != status_lib.Code.OK:
                _update_state(
                    status_lib.Status(
                        code=executor_output.execution_result.code,
                        message=executor_output.execution_result.result_message
                    ))
                return
            # TODO(b/182316162): Unify publisher handing so that post-execution
            # artifact logic is more cleanly handled.
            outputs_utils.tag_executor_output_with_version(executor_output)
        # `executor_output` may be `None` here; it is passed through as-is.
        _remove_task_dirs(task)
        execution_publish_utils.publish_succeeded_execution(
            mlmd_handle,
            execution_id=task.execution_id,
            contexts=task.contexts,
            output_artifacts=task.output_artifacts,
            executor_output=executor_output)
    elif isinstance(result.output, ts.ImporterNodeOutput):
        # Importer results publish the artifacts carried by the result rather
        # than `task.output_artifacts`.
        output_artifacts = result.output.output_artifacts
        # TODO(b/182316162): Unify publisher handing so that post-execution artifact
        # logic is more cleanly handled.
        outputs_utils.tag_output_artifacts_with_version(output_artifacts)
        _remove_task_dirs(task)
        execution_publish_utils.publish_succeeded_execution(
            mlmd_handle,
            execution_id=task.execution_id,
            contexts=task.contexts,
            output_artifacts=output_artifacts)
    elif isinstance(result.output, ts.ResolverNodeOutput):
        # Resolver results go through `publish_internal_execution`; unlike the
        # other branches, `_remove_task_dirs` is not called here.
        resolved_input_artifacts = result.output.resolved_input_artifacts
        execution_publish_utils.publish_internal_execution(
            mlmd_handle,
            execution_id=task.execution_id,
            contexts=task.contexts,
            output_artifacts=resolved_input_artifacts)
    else:
        raise TypeError(f'Unable to process task scheduler result: {result}')

    # Reached only on the success paths above; the non-OK paths return early
    # after `_update_state`, which makes the same call.
    pipeline_state.record_state_change_time()