Exemplo n.º 1
0
def _wait_for_inactivation(
        mlmd_handle: metadata.Metadata,
        execution_id: int,
        timeout_secs: float = DEFAULT_WAIT_FOR_INACTIVATION_TIMEOUT_SECS
) -> None:
    """Waits for the given execution to become inactive.

    MLMD is polled roughly every `min(10, timeout_secs / 4)` seconds. The
    execution is always checked at least once (even for a zero timeout) and
    once more after the final sleep, so an execution that becomes inactive
    right at the deadline is not reported as a timeout.

    Args:
      mlmd_handle: A handle to the MLMD db.
      execution_id: Id of the execution whose inactivation is awaited.
      timeout_secs: Amount of time in seconds to wait.

    Raises:
      StatusNotOkError: With error code `NOT_FOUND` if no execution with
        `execution_id` exists, or `DEADLINE_EXCEEDED` if the execution is not
        inactive after waiting approx. `timeout_secs`.
    """
    polling_interval_secs = min(10.0, timeout_secs / 4)
    end_time = time.time() + timeout_secs
    while True:
        updated_executions = mlmd_handle.store.get_executions_by_id(
            [execution_id])
        # Guard against a missing execution instead of failing with IndexError.
        if not updated_executions:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.NOT_FOUND,
                message=f'No execution found with id {execution_id}.')
        if not execution_lib.is_execution_active(updated_executions[0]):
            return
        remaining_secs = end_time - time.time()
        if remaining_secs <= 0:
            break
        time.sleep(max(0, min(polling_interval_secs, remaining_secs)))
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.DEADLINE_EXCEEDED,
        message=(f'Timed out ({timeout_secs} secs) waiting for execution '
                 f'inactivation.'))
Exemplo n.º 2
0
def initiate_pipeline_start(
        mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline) -> pstate.PipelineState:
    """Initiates a pipeline start operation.

    Upon success, MLMD is updated to signal that the pipeline must be started.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: IR of the pipeline to start. The caller's proto is not
        modified; a deep copy is handed to `PipelineState.new`.

    Returns:
      The `PipelineState` object upon success.

    Raises:
      status_lib.StatusNotOkError: Failure to initiate pipeline start. With code
        `INVALID_ARGUMENT` if it's a sync pipeline without `pipeline_run_id`
        provided. Errors from `PipelineState.new` (e.g. when the pipeline is
        already active) propagate unchanged.
    """
    pipeline = copy.deepcopy(pipeline)
    # SYNC runs are identified by their run id, so a non-empty string value is
    # required.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC and not (
            pipeline.runtime_spec.pipeline_run_id.HasField('field_value') and
            pipeline.runtime_spec.pipeline_run_id.field_value.string_value):
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.INVALID_ARGUMENT,
            message='Sync pipeline IR must specify pipeline_run_id.')

    return pstate.PipelineState.new(mlmd_handle, pipeline)
Exemplo n.º 3
0
 def get_node_state(self, node_uid: task_lib.NodeUid) -> NodeState:
     """Returns the recorded state of a node in this pipeline.

     A node with no recorded state gets a freshly constructed default
     `NodeState`.

     Args:
       node_uid: Uid of the node to look up.

     Raises:
       status_lib.StatusNotOkError: With code `INVALID_ARGUMENT` if the node
         does not belong to this pipeline.
     """
     self._check_context()
     node_belongs = _is_node_uid_in_pipeline(node_uid, self.pipeline)
     if node_belongs:
         states_by_node_id = _get_node_states_dict(self._execution)
         return states_by_node_id.get(node_uid.node_id, NodeState())
     raise status_lib.StatusNotOkError(
         code=status_lib.Code.INVALID_ARGUMENT,
         message=(f'Node {node_uid} does not belong to the pipeline '
                  f'{self.pipeline_uid}'))
Exemplo n.º 4
0
def orchestrate(mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
                service_job_manager: service_jobs.ServiceJobManager) -> None:
    """Performs a single iteration of the orchestration loop.

    Embodies the core functionality of the main orchestration loop that scans
    MLMD pipeline execution states, generates and enqueues the tasks to be
    performed. Stop-initiated pipelines are handled first, then
    update-initiated ones, then the remaining active ones.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task_queue: A `TaskQueue` instance into which any tasks will be enqueued.
      service_job_manager: A `ServiceJobManager` instance for handling service
        jobs.

    Raises:
      status_lib.StatusNotOkError: If error generating tasks, or with code
        `INTERNAL` if a loaded pipeline is neither active, stop-initiated nor
        update-initiated.
    """
    pipeline_states = _get_pipeline_states(mlmd_handle)
    if not pipeline_states:
        logging.info('No active pipelines to run.')
        return

    active_pipeline_states = []
    stop_initiated_pipeline_states = []
    update_initiated_pipeline_states = []
    for pipeline_state in pipeline_states:
        with pipeline_state:
            # Stop takes precedence over update: a pipeline that is both stop-
            # and update-initiated is only orchestrated for stopping.
            if pipeline_state.is_stop_initiated():
                stop_initiated_pipeline_states.append(pipeline_state)
            elif pipeline_state.is_update_initiated():
                update_initiated_pipeline_states.append(pipeline_state)
            elif pipeline_state.is_active():
                active_pipeline_states.append(pipeline_state)
            else:
                raise status_lib.StatusNotOkError(
                    code=status_lib.Code.INTERNAL,
                    message=(
                        f'Found pipeline (uid: {pipeline_state.pipeline_uid}) '
                        f'which is neither active, stop-initiated nor '
                        f'update-initiated.'))

    for pipeline_state in stop_initiated_pipeline_states:
        logging.info('Orchestrating stop-initiated pipeline: %s',
                     pipeline_state.pipeline_uid)
        _orchestrate_stop_initiated_pipeline(mlmd_handle, task_queue,
                                             service_job_manager,
                                             pipeline_state)

    for pipeline_state in update_initiated_pipeline_states:
        logging.info('Orchestrating update-initiated pipeline: %s',
                     pipeline_state.pipeline_uid)
        _orchestrate_update_initiated_pipeline(mlmd_handle, task_queue,
                                               service_job_manager,
                                               pipeline_state)

    for pipeline_state in active_pipeline_states:
        logging.info('Orchestrating pipeline: %s', pipeline_state.pipeline_uid)
        _orchestrate_active_pipeline(mlmd_handle, task_queue,
                                     service_job_manager, pipeline_state)
Exemplo n.º 5
0
 def _wrapper(*args, **kwargs):
     """Calls `fn`, converting unexpected errors into `StatusNotOkError`.

     Every exception raised by `fn` is logged. A `StatusNotOkError` is
     re-raised unchanged. Any other exception is wrapped in a
     `StatusNotOkError` with code `UNKNOWN`, chained to the original exception
     so that its traceback is preserved.
     """
     try:
         return fn(*args, **kwargs)
     except Exception as e:  # pylint: disable=broad-except
         logging.exception('Error raised by `%s`:', fn.__name__)
         if isinstance(e, status_lib.StatusNotOkError):
             raise
         raise status_lib.StatusNotOkError(
             code=status_lib.Code.UNKNOWN,
             message=f'`{fn.__name__}` error: {str(e)}') from e
Exemplo n.º 6
0
 def initiate_node_start(self, node_uid: task_lib.NodeUid) -> None:
     """Updates pipeline state to signal that a node should be started.

     Only ASYNC pipelines are supported. If a stop had previously been
     initiated for the node, the stop marker and the recorded stop status
     code and message are removed, and the state is flagged for commit. If no
     stop marker is present, nothing changes.

     Args:
       node_uid: Uid of the node to start.

     Raises:
       status_lib.StatusNotOkError: With code `UNIMPLEMENTED` for non-ASYNC
         pipelines, or `INVALID_ARGUMENT` if the node does not belong to the
         pipeline.
     """
     is_async = self.pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
     if not is_async:
         raise status_lib.StatusNotOkError(
             code=status_lib.Code.UNIMPLEMENTED,
             message='Node can be started only for async pipelines.')
     if not _is_node_uid_in_pipeline(node_uid, self.pipeline):
         raise status_lib.StatusNotOkError(
             code=status_lib.Code.INVALID_ARGUMENT,
             message=(
                 f'Node given by uid {node_uid} does not belong to pipeline '
                 f'given by uid {self.pipeline_uid}'))
     removed_stop_marker = self.execution.custom_properties.pop(
         _node_stop_initiated_property(node_uid), None)
     if removed_stop_marker is None:
         # No stop was pending for this node; nothing to undo.
         return
     # The node is no longer being stopped, so its stop status is stale too.
     self.execution.custom_properties.pop(
         _node_status_code_property(node_uid), None)
     self.execution.custom_properties.pop(
         _node_status_msg_property(node_uid), None)
     self._commit = True
Exemplo n.º 7
0
def _get_active_execution(
    pipeline_uid: task_lib.PipelineUid,
    executions: List[metadata_store_pb2.Execution]
) -> metadata_store_pb2.Execution:
    """Returns the single active execution among `executions`.

    Args:
      pipeline_uid: Uid of the pipeline the executions belong to; used only in
        error messages.
      executions: Candidate pipeline executions.

    Returns:
      The only execution for which `execution_lib.is_execution_active` is true.

    Raises:
      status_lib.StatusNotOkError: With code `NOT_FOUND` if no execution is
        active, or `INTERNAL` if more than one is.
    """
    active = list(filter(execution_lib.is_execution_active, executions))
    if len(active) == 1:
        return active[0]
    if active:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.INTERNAL,
            message=(f'Expected 1 but found {len(active)} active pipeline '
                     f'executions for pipeline uid: {pipeline_uid}'))
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.NOT_FOUND,
        message=f'No active pipeline with uid {pipeline_uid} to load state.')
Exemplo n.º 8
0
    def load(cls,
             mlmd_handle: metadata.Metadata,
             pipeline_uid: task_lib.PipelineUid,
             pipeline_run_id: Optional[str] = None) -> 'PipelineView':
        """Loads pipeline view from MLMD.

        If `pipeline_run_id` is None, the view is built from the execution
        selected by `_get_latest_execution` among all executions under the
        pipeline's orchestrator context, whether active or not. Otherwise the
        first execution whose recorded run id equals `pipeline_run_id` is used.

        Args:
          mlmd_handle: A handle to the MLMD db.
          pipeline_uid: Uid of the pipeline state to load.
          pipeline_run_id: Run id of the pipeline for the synchronous pipeline.

        Returns:
          A `PipelineView` object.

        Raises:
          status_lib.StatusNotOkError: With code=NOT_FOUND if no pipeline with
            the given pipeline uid exists in MLMD, if it has no executions
            (when `pipeline_run_id` is None), or if no execution has the given
            `pipeline_run_id`.
        """
        context = mlmd_handle.store.get_context_by_type_and_name(
            type_name=_ORCHESTRATOR_RESERVED_ID,
            context_name=orchestrator_context_name(pipeline_uid))
        if not context:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.NOT_FOUND,
                message=f'No pipeline with uid {pipeline_uid} found.')
        executions = mlmd_handle.store.get_executions_by_context(context.id)

        if pipeline_run_id is None:
            # Without a run id, only the latest execution is of interest; report
            # the missing executions explicitly instead of a "run_id None" error.
            if not executions:
                raise status_lib.StatusNotOkError(
                    code=status_lib.Code.NOT_FOUND,
                    message=(f'No pipeline executions found for pipeline uid '
                             f'{pipeline_uid}.'))
            return cls(pipeline_uid, context, _get_latest_execution(executions))

        for execution in executions:
            if execution.custom_properties[
                    _PIPELINE_RUN_ID].string_value == pipeline_run_id:
                return cls(pipeline_uid, context, execution)
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.NOT_FOUND,
            message=f'No pipeline with run_id {pipeline_run_id} found.')
Exemplo n.º 9
0
    def load_from_orchestrator_context(
            cls, mlmd_handle: metadata.Metadata,
            context: metadata_store_pb2.Context) -> 'PipelineState':
        """Loads the active pipeline's state under an orchestrator context.

        Args:
          mlmd_handle: A handle to the MLMD db.
          context: Orchestrator context under which to find the active
            pipeline execution.

        Returns:
          A `PipelineState` object built with `commit=False`.

        Raises:
          status_lib.StatusNotOkError: With code=NOT_FOUND if the context has no
            active pipeline execution, or code=INTERNAL if it has more than one.
        """
        pipeline_uid = pipeline_uid_from_orchestrator_context(context)
        all_executions = mlmd_handle.store.get_executions_by_context(context.id)
        active_executions = list(
            filter(execution_lib.is_execution_active, all_executions))
        num_active = len(active_executions)
        if num_active == 0:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.NOT_FOUND,
                message=
                f'No active pipeline with uid {pipeline_uid} to load state.')
        if num_active > 1:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.INTERNAL,
                message=(f'Expected 1 but found {num_active} active pipeline '
                         f'executions for pipeline uid: {pipeline_uid}'))

        (active_execution,) = active_executions
        return cls(mlmd_handle=mlmd_handle,
                   pipeline_uid=pipeline_uid,
                   context=context,
                   execution=active_execution,
                   commit=False)
Exemplo n.º 10
0
    def new(cls, mlmd_handle: metadata.Metadata,
            pipeline: pipeline_pb2.Pipeline) -> 'PipelineState':
        """Creates a `PipelineState` object for a new pipeline.

        No active pipeline with the same pipeline uid should exist for the call
        to be successful. A new orchestrator execution storing the
        base64-encoded pipeline IR is registered under the pipeline's
        orchestrator context (created if missing). For SYNC pipelines, the
        pipeline run id is also recorded on the execution.

        Args:
          mlmd_handle: A handle to the MLMD db.
          pipeline: IR of the pipeline.

        Returns:
          A `PipelineState` object.

        Raises:
          status_lib.StatusNotOkError: With code `ALREADY_EXISTS` if an active
            pipeline with the same UID already exists.
        """
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        context = context_lib.register_context_if_not_exists(
            mlmd_handle,
            context_type_name=_ORCHESTRATOR_RESERVED_ID,
            context_name=orchestrator_context_name(pipeline_uid))

        executions = mlmd_handle.store.get_executions_by_context(context.id)
        # Test the activity predicate directly rather than relying on the
        # truthiness of the execution protos yielded by a filtering generator.
        if any(execution_lib.is_execution_active(e) for e in executions):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.ALREADY_EXISTS,
                message=f'Pipeline with uid {pipeline_uid} already active.')

        execution = execution_lib.prepare_execution(
            mlmd_handle,
            _ORCHESTRATOR_EXECUTION_TYPE,
            metadata_store_pb2.Execution.NEW,
            exec_properties={
                _PIPELINE_IR:
                base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
            },
        )
        if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            data_types_utils.set_metadata_value(
                execution.custom_properties[_PIPELINE_RUN_ID],
                pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

        execution = execution_lib.put_execution(mlmd_handle, execution,
                                                [context])
        record_state_change_time()

        return cls(mlmd_handle=mlmd_handle,
                   pipeline=pipeline,
                   execution_id=execution.id)
Exemplo n.º 11
0
 def apply_pipeline_update(self) -> None:
     """Applies pipeline update that was previously initiated.

     The pending updated pipeline IR replaces the execution's pipeline IR
     property. The pending-update custom properties (updated IR and update
     options) are removed, and `self.pipeline` is set to the decoded updated
     pipeline.

     Raises:
       status_lib.StatusNotOkError: With code `INVALID_ARGUMENT` if there is no
         updated pipeline IR to apply.
     """
     self._check_context()
     updated_pipeline_ir = _get_metadata_value(
         self._execution.custom_properties.get(_UPDATED_PIPELINE_IR))
     if not updated_pipeline_ir:
         raise status_lib.StatusNotOkError(
             code=status_lib.Code.INVALID_ARGUMENT,
             message='No updated pipeline IR to apply')
     # Decode before mutating the execution so a malformed IR does not leave the
     # execution half-updated.
     updated_pipeline = _base64_decode_pipeline(updated_pipeline_ir)
     data_types_utils.set_metadata_value(
         self._execution.properties[_PIPELINE_IR], updated_pipeline_ir)
     # Guaranteed present: its value was read (and found non-empty) above.
     del self._execution.custom_properties[_UPDATED_PIPELINE_IR]
     # Tolerate missing update options rather than failing with KeyError after
     # the pipeline IR has already been overwritten.
     self._execution.custom_properties.pop(_UPDATE_OPTIONS, None)
     self.pipeline = updated_pipeline
Exemplo n.º 12
0
 def initiate_node_stop(self, node_uid: task_lib.NodeUid,
                        status: status_lib.Status) -> None:
     """Updates pipeline state to signal that a node should be stopped.

     A stop marker and the status code are recorded for the node in the
     pipeline execution's custom properties. The status message is recorded
     when non-empty. Otherwise any message left from an earlier stop request
     is cleared, so the recorded code and message always belong together.

     Args:
       node_uid: Uid of the node to stop.
       status: Status recorded as the reason for stopping.

     Raises:
       status_lib.StatusNotOkError: With code `UNIMPLEMENTED` for non-ASYNC
         pipelines, or `INVALID_ARGUMENT` if the node does not belong to the
         pipeline.
     """
     self._check_context()
     if self.pipeline.execution_mode != pipeline_pb2.Pipeline.ASYNC:
         raise status_lib.StatusNotOkError(
             code=status_lib.Code.UNIMPLEMENTED,
             message='Node can be stopped only for async pipelines.')
     if not _is_node_uid_in_pipeline(node_uid, self.pipeline):
         raise status_lib.StatusNotOkError(
             code=status_lib.Code.INVALID_ARGUMENT,
             message=(
                 f'Node given by uid {node_uid} does not belong to pipeline '
                 f'given by uid {self.pipeline_uid}'))
     data_types_utils.set_metadata_value(
         self._execution.custom_properties[_node_stop_initiated_property(
             node_uid)], 1)
     data_types_utils.set_metadata_value(
         self._execution.custom_properties[_node_status_code_property(
             node_uid)], int(status.code))
     msg_key = _node_status_msg_property(node_uid)
     if status.message:
         data_types_utils.set_metadata_value(
             self._execution.custom_properties[msg_key], status.message)
     else:
         # Drop a stale message so it is not paired with the new status code.
         self._execution.custom_properties.pop(msg_key, None)
Exemplo n.º 13
0
 def node_state_update_context(
         self, node_uid: task_lib.NodeUid) -> Iterator[NodeState]:
     """Yields a node's state for in-place modification, then persists it.

     Intended for use as a context manager (presumably wrapped with
     `contextlib.contextmanager`; the decorator is not visible here). After the
     caller's block finishes, a transition is logged if the state value
     changed, and the node-states dict is written back to the pipeline
     execution. The save follows the `yield` without `try`/`finally`, so it is
     skipped if the caller's block raises.

     Args:
       node_uid: Uid of the node whose state is updated. A default `NodeState`
         is inserted for nodes with no recorded state.

     Raises:
       status_lib.StatusNotOkError: With code `INVALID_ARGUMENT` if the node
         does not belong to this pipeline.
     """
     self._check_context()
     if not _is_node_uid_in_pipeline(node_uid, self.pipeline):
         raise status_lib.StatusNotOkError(
             code=status_lib.Code.INVALID_ARGUMENT,
             message=(f'Node {node_uid} does not belong to the pipeline '
                      f'{self.pipeline_uid}'))
     states_by_node_id = _get_node_states_dict(self._execution)
     state_obj = states_by_node_id.setdefault(node_uid.node_id, NodeState())
     state_before = state_obj.state
     yield state_obj
     state_after = state_obj.state
     if state_after != state_before:
         logging.info('Changing node state: %s -> %s; node uid: %s',
                      state_before, state_after, node_uid)
     _save_node_states_dict(self._execution, states_by_node_id)
Exemplo n.º 14
0
def stop_node(mlmd_handle: metadata.Metadata,
              node_uid: task_lib.NodeUid,
              timeout_secs: Optional[float] = None) -> None:
    """Stops a node.

    Initiates a node stop operation and waits for the node execution to become
    inactive. A node whose state is stoppable is moved to STOPPING with a
    `CANCELLED` status. Other nodes are left as they are, but the wait still
    happens.

    Args:
      mlmd_handle: A handle to the MLMD db.
      node_uid: Uid of the node to be stopped.
      timeout_secs: Amount of time in seconds to wait for node to stop. If
        `None`, waits indefinitely.

    Raises:
      status_lib.StatusNotOkError: Failure to stop the node. With code
        `INTERNAL` if the pipeline IR does not contain exactly one node with
        the requested id.
    """
    logging.info('Received request to stop node; node uid: %s', node_uid)
    with _PIPELINE_OPS_LOCK:
        with pstate.PipelineState.load(
                mlmd_handle, node_uid.pipeline_uid) as pipeline_state:
            nodes = pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
            # Validate that exactly one node in the IR carries the requested
            # id before touching its state.
            matching_nodes = [
                n for n in nodes if n.node_info.id == node_uid.node_id
            ]
            if len(matching_nodes) != 1:
                raise status_lib.StatusNotOkError(
                    code=status_lib.Code.INTERNAL,
                    message=
                    (f'`stop_node` operation failed, unable to find node to stop: '
                     f'{node_uid}'))
            with pipeline_state.node_state_update_context(
                    node_uid) as node_state:
                if node_state.is_stoppable():
                    node_state.update(
                        pstate.NodeState.STOPPING,
                        status_lib.Status(
                            code=status_lib.Code.CANCELLED,
                            message='Cancellation requested by client.'))

    # Wait until the node is stopped or time out. This happens after the ops
    # lock is released (presumably so other pipeline ops are not blocked).
    _wait_for_node_inactivation(pipeline_state,
                                node_uid,
                                timeout_secs=timeout_secs)
Exemplo n.º 15
0
def _wait_for_predicate(predicate_fn: Callable[[],
                                               bool], waiting_for_desc: str,
                        timeout_secs: Optional[float]) -> None:
    """Waits for `predicate_fn` to return `True` or until timeout seconds elapse.

    Args:
      predicate_fn: Zero-argument callable polled until it returns `True`.
      waiting_for_desc: Human-readable description used in the timeout error.
      timeout_secs: Maximum time to wait. If `None`, waits indefinitely.

    Raises:
      status_lib.StatusNotOkError: With code `DEADLINE_EXCEEDED` if
        `predicate_fn` has not returned `True` after `timeout_secs`.
    """
    if timeout_secs is None:
        while not predicate_fn():
            time.sleep(_POLLING_INTERVAL_SECS)
        return
    polling_interval_secs = min(_POLLING_INTERVAL_SECS, timeout_secs / 4)
    end_time = time.time() + timeout_secs
    while True:
        # Check the predicate before the deadline so that it is evaluated at
        # least once (even for a zero timeout) and once more after the final
        # sleep.
        if predicate_fn():
            return
        remaining_secs = end_time - time.time()
        if remaining_secs <= 0:
            break
        time.sleep(max(0, min(polling_interval_secs, remaining_secs)))
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.DEADLINE_EXCEEDED,
        message=(
            f'Timed out ({timeout_secs} secs) waiting for {waiting_for_desc}.'
        ))
Exemplo n.º 16
0
    def load(cls, mlmd_handle: metadata.Metadata,
             pipeline_uid: task_lib.PipelineUid) -> 'PipelineState':
        """Loads the active pipeline state for `pipeline_uid` from MLMD.

        The pipeline's orchestrator context is looked up by name. Loading the
        active execution under it is delegated to
        `load_from_orchestrator_context`.

        Args:
          mlmd_handle: A handle to the MLMD db.
          pipeline_uid: Uid of the pipeline state to load.

        Returns:
          A `PipelineState` object.

        Raises:
          status_lib.StatusNotOkError: With code=NOT_FOUND if no pipeline
            context, or no active execution, exists for the given pipeline uid.
            With code=INTERNAL if more than 1 active execution exists.
        """
        orchestrator_context = mlmd_handle.store.get_context_by_type_and_name(
            type_name=_ORCHESTRATOR_RESERVED_ID,
            context_name=orchestrator_context_name(pipeline_uid))
        if orchestrator_context:
            return cls.load_from_orchestrator_context(mlmd_handle,
                                                      orchestrator_context)
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.NOT_FOUND,
            message=f'No pipeline with uid {pipeline_uid} found.')
Exemplo n.º 17
0
def _orchestrate_active_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates active pipeline for one iteration.

    Steps, in order:
      1. Marks the pipeline execution RUNNING if it is not already. For SYNC
         pipelines with a positive `deadline_secs`, initiates a stop with
         `DEADLINE_EXCEEDED` and returns early once the time since pipeline
         creation exceeds the deadline.
      2. For nodes in state STOPPING, stops their services or enqueues
         cancellation tasks. Nodes determined to be stopped are moved to
         STOPPED, keeping their existing status.
      3. Generates tasks with the SYNC or ASYNC task generator.
      4. Applies `UpdateNodeStateTask`s, moves nodes still in STARTING to
         STARTED, and enqueues `ExecNodeTask`s. For a `FinalizePipelineTask`,
         initiates a pipeline stop with its status.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task_queue: Queue into which generated exec and cancellation tasks are
        enqueued.
      service_job_manager: Manager used to stop services of service nodes.
      pipeline_state: State of the active pipeline to orchestrate.

    Raises:
      status_lib.StatusNotOkError: With code `FAILED_PRECONDITION` if the
        pipeline execution mode is neither SYNC nor ASYNC.
    """
    pipeline = pipeline_state.pipeline
    # Ensure the execution is RUNNING and enforce the SYNC pipeline deadline.
    with pipeline_state:
        assert pipeline_state.is_active()
        if pipeline_state.get_pipeline_execution_state() != (
                metadata_store_pb2.Execution.RUNNING):
            pipeline_state.set_pipeline_execution_state(
                metadata_store_pb2.Execution.RUNNING)
        # Also used below (`fail_fast`) when building the SYNC task generator.
        orchestration_options = pipeline_state.get_orchestration_options()
        logging.info('Orchestration options: %s', orchestration_options)
        deadline_secs = orchestration_options.deadline_secs
        # A non-positive `deadline_secs` disables the check; elapsed time is
        # measured from the pipeline's creation time.
        if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                and deadline_secs > 0 and time.time() -
                pipeline_state.pipeline_creation_time_secs_since_epoch() >
                deadline_secs):
            logging.error(
                'Aborting pipeline due to exceeding deadline (%s secs); '
                'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
            pipeline_state.initiate_stop(
                status_lib.Status(
                    code=status_lib.Code.DEADLINE_EXCEEDED,
                    message=('Pipeline aborted due to exceeding deadline '
                             f'({deadline_secs} secs)')))
            return

    def _filter_by_state(node_infos: List[_NodeInfo],
                         state_str: str) -> List[_NodeInfo]:
        # Keeps only the node infos whose current node state equals `state_str`.
        return [n for n in node_infos if n.state.state == state_str]

    node_infos = _get_node_infos(pipeline_state)
    stopping_node_infos = _filter_by_state(node_infos,
                                           pstate.NodeState.STOPPING)

    # Tracks nodes stopped in the current iteration.
    stopped_node_infos: List[_NodeInfo] = []

    # Create cancellation tasks for nodes in state STOPPING.
    for node_info in stopping_node_infos:
        if service_job_manager.is_pure_service_node(
                pipeline_state, node_info.node.node_info.id):
            # Counted as stopped only if `stop_node_services` returns a truthy
            # value.
            if service_job_manager.stop_node_services(
                    pipeline_state, node_info.node.node_info.id):
                stopped_node_infos.append(node_info)
        elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                              node_info.node, task_queue):
            # Truthy return (presumably: a cancellation task was enqueued);
            # the node is left in STOPPING for now.
            pass
        elif service_job_manager.is_mixed_service_node(
                pipeline_state, node_info.node.node_info.id):
            if service_job_manager.stop_node_services(
                    pipeline_state, node_info.node.node_info.id):
                stopped_node_infos.append(node_info)
        else:
            # Nothing left to cancel: the node is stopped right away.
            stopped_node_infos.append(node_info)

    # Change the state of stopped nodes from STOPPING to STOPPED.
    if stopped_node_infos:
        with pipeline_state:
            for node_info in stopped_node_infos:
                node_uid = task_lib.NodeUid.from_pipeline_node(
                    pipeline, node_info.node)
                with pipeline_state.node_state_update_context(
                        node_uid) as node_state:
                    # Keep the status recorded when the stop was requested.
                    node_state.update(pstate.NodeState.STOPPED,
                                      node_state.status)

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle,
            task_queue.contains_task_id,
            service_job_manager,
            fail_fast=orchestration_options.fail_fast)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, task_queue.contains_task_id, service_job_manager)
    else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    tasks = generator.generate(pipeline_state)

    with pipeline_state:
        # Handle all the UpdateNodeStateTasks by updating node states.
        for task in tasks:
            if task_lib.is_update_node_state_task(task):
                task = typing.cast(task_lib.UpdateNodeStateTask, task)
                with pipeline_state.node_state_update_context(
                        task.node_uid) as node_state:
                    node_state.update(task.state, task.status)

        tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

        # If there are still nodes in state STARTING, change them to STARTED.
        for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
            node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline_state.pipeline, node)
            with pipeline_state.node_state_update_context(
                    node_uid) as node_state:
                if node_state.state == pstate.NodeState.STARTING:
                    node_state.update(pstate.NodeState.STARTED)

        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task = typing.cast(task_lib.ExecNodeTask, task)
                task_queue.enqueue(task)
            else:
                # Any other task must be a FinalizePipelineTask. The asserts
                # require it to be the only task and to come from a SYNC
                # pipeline.
                assert task_lib.is_finalize_pipeline_task(task)
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                assert len(tasks) == 1
                task = typing.cast(task_lib.FinalizePipelineTask, task)
                if task.status.code == status_lib.Code.OK:
                    logging.info('Pipeline run successful; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                else:
                    logging.info('Pipeline run failed; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                pipeline_state.initiate_stop(task.status)
Exemplo n.º 18
0
def _orchestrate_active_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates active pipeline for one iteration.

    If the pipeline execution is still NEW, a RUNNING copy of it is written
    directly to MLMD. A SYNC or ASYNC task generator is then built. For ASYNC
    pipelines, services are first stopped, or cancellation tasks enqueued, for
    stop-initiated nodes. Of the generated tasks, `ExecNodeTask`s are
    enqueued, `FinalizeNodeTask`s initiate a node stop, and a
    `FinalizePipelineTask` initiates a pipeline stop with its status.

    Args:
      mlmd_handle: A handle to the MLMD db.
      task_queue: Queue into which generated exec and cancellation tasks are
        enqueued.
      service_job_manager: Manager used to stop services of service nodes.
      pipeline_state: State of the active pipeline to orchestrate.

    Raises:
      status_lib.StatusNotOkError: With code `FAILED_PRECONDITION` if the
        pipeline execution mode is neither SYNC nor ASYNC.
    """
    pipeline = pipeline_state.pipeline
    execution = pipeline_state.execution
    # An active pipeline execution is expected to be either NEW or RUNNING.
    assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                          metadata_store_pb2.Execution.RUNNING)
    if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
        # Writes an updated copy; `pipeline_state.execution` itself is left
        # unchanged.
        updated_execution = copy.deepcopy(execution)
        updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
        mlmd_handle.store.put_executions([updated_execution])

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        # Create cancellation tasks for stop-initiated nodes if necessary.
        # The results of `stop_node_services` are not inspected here.
        stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
        for node in stop_initiated_nodes:
            if service_job_manager.is_pure_service_node(
                    pipeline_state, node.node_info.id):
                service_job_manager.stop_node_services(pipeline_state,
                                                       node.node_info.id)
            elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                                  task_queue):
                pass
            elif service_job_manager.is_mixed_service_node(
                    pipeline_state, node.node_info.id):
                service_job_manager.stop_node_services(pipeline_state,
                                                       node.node_info.id)
        # The ids of stop-initiated nodes are handed to the generator as well.
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager,
            set(n.node_info.id for n in stop_initiated_nodes))
    else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    tasks = generator.generate()

    with pipeline_state:
        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task = typing.cast(task_lib.ExecNodeTask, task)
                task_queue.enqueue(task)
            elif task_lib.is_finalize_node_task(task):
                # Node finalization is only expected for ASYNC pipelines.
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
                task = typing.cast(task_lib.FinalizeNodeTask, task)
                pipeline_state.initiate_node_stop(task.node_uid, task.status)
            else:
                # Any other task must be the sole FinalizePipelineTask of a
                # SYNC pipeline, as the asserts require.
                assert task_lib.is_finalize_pipeline_task(task)
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                assert len(tasks) == 1
                task = typing.cast(task_lib.FinalizePipelineTask, task)
                if task.status.code == status_lib.Code.OK:
                    logging.info('Pipeline run successful; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                else:
                    logging.info('Pipeline run failed; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                pipeline_state.initiate_stop(task.status)
Exemplo n.º 19
0
 def fn2():
     """Always raises a `StatusNotOkError` with code `ALREADY_EXISTS`."""
     error = status_lib.StatusNotOkError(
         message='test error 2', code=status_lib.Code.ALREADY_EXISTS)
     raise error
Exemplo n.º 20
0
    def new(
        cls,
        mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline,
        pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
    ) -> 'PipelineState':
        """Creates a `PipelineState` object for a new pipeline.

        No active pipeline with the same pipeline uid should exist for the call
        to be successful. A new orchestrator execution is registered under the
        pipeline's orchestrator context (created if missing). It stores the
        encoded pipeline IR and, if given, the JSON-serialized run metadata.
        For SYNC pipelines, the pipeline run id is recorded, and nodes whose
        `execution_options` set `skip` are pre-marked COMPLETE.

        Args:
          mlmd_handle: A handle to the MLMD db.
          pipeline: IR of the pipeline.
          pipeline_run_metadata: Pipeline run metadata.

        Returns:
          A `PipelineState` object.

        Raises:
          status_lib.StatusNotOkError: With code `ALREADY_EXISTS` if an active
            pipeline with the same UID already exists.
        """
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        context = context_lib.register_context_if_not_exists(
            mlmd_handle,
            context_type_name=_ORCHESTRATOR_RESERVED_ID,
            context_name=orchestrator_context_name(pipeline_uid))

        executions = mlmd_handle.store.get_executions_by_context(context.id)
        # Test the activity predicate directly rather than relying on the
        # truthiness of the execution protos yielded by a filtering generator.
        if any(execution_lib.is_execution_active(e) for e in executions):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.ALREADY_EXISTS,
                message=f'Pipeline with uid {pipeline_uid} already active.')

        exec_properties = {_PIPELINE_IR: _base64_encode(pipeline)}
        if pipeline_run_metadata:
            exec_properties[_PIPELINE_RUN_METADATA] = json_utils.dumps(
                pipeline_run_metadata)

        execution = execution_lib.prepare_execution(
            mlmd_handle,
            _ORCHESTRATOR_EXECUTION_TYPE,
            metadata_store_pb2.Execution.NEW,
            exec_properties=exec_properties)
        if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
            data_types_utils.set_metadata_value(
                execution.custom_properties[_PIPELINE_RUN_ID],
                pipeline.runtime_spec.pipeline_run_id.field_value.string_value)
            # Set the node state to COMPLETE for any nodes that are marked to be
            # skipped in a partial pipeline run.
            node_states_dict = {}
            for node in get_all_pipeline_nodes(pipeline):
                if node.execution_options.HasField('skip'):
                    logging.info('Node %s is skipped in this partial run.',
                                 node.node_info.id)
                    node_states_dict[node.node_info.id] = NodeState(
                        state=NodeState.COMPLETE)
            if node_states_dict:
                _save_node_states_dict(execution, node_states_dict)

        execution = execution_lib.put_execution(mlmd_handle, execution,
                                                [context])
        record_state_change_time()

        return cls(mlmd_handle=mlmd_handle,
                   pipeline=pipeline,
                   execution_id=execution.id)
Exemplo n.º 21
0
def resume_pipeline(mlmd_handle: metadata.Metadata,
                    pipeline: pipeline_pb2.Pipeline) -> pstate.PipelineState:
    """Resumes a pipeline run from previously failed nodes.

  Finds the most recently created run of the pipeline and starts a new
  partial run. Every node of `pipeline` is in the run range. Nodes that
  succeeded in that latest run are passed as `skip_nodes`. Upon success, MLMD
  is updated to signal that the pipeline must be started.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline to resume.

  Returns:
    The `PipelineState` object upon success, as returned by
    `initiate_pipeline_start`.

  Raises:
    status_lib.StatusNotOkError: Failure to resume pipeline.
      - `FAILED_PRECONDITION`: `pipeline`, or the latest previous run, does
        not use the SYNC execution mode.
      - `ALREADY_EXISTS`: a run of a pipeline with the same uid is still
        active.
      - `NOT_FOUND`: no previous pipeline run is found for resuming.
      Errors raised by `initiate_pipeline_start` are propagated unchanged.
  """
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    logging.info('Received request to resume pipeline; pipeline uid: %s',
                 pipeline_uid)
    if pipeline.execution_mode != pipeline_pb2.Pipeline.SYNC:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    # Scan all previous runs. Refuse to resume while any run is still active.
    # Otherwise remember the run with the greatest creation time; on ties the
    # first view loaded wins.
    latest_pipeline_view = None
    for view in pstate.PipelineView.load_all(mlmd_handle, pipeline_uid):
        execution = view.execution
        if execution_lib.is_execution_active(execution):
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.ALREADY_EXISTS,
                message=(
                    f'Can not resume pipeline. An active pipeline is already '
                    f'running with uid {pipeline_uid}.'))
        if (latest_pipeline_view is None or
                execution.create_time_since_epoch >
                latest_pipeline_view.execution.create_time_since_epoch):
            latest_pipeline_view = view

    if latest_pipeline_view is None:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.NOT_FOUND,
            message='Pipeline failed to resume. No previous pipeline run found.'
        )
    if latest_pipeline_view.pipeline.execution_mode != pipeline_pb2.Pipeline.SYNC:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=
            (f'Only SYNC pipeline execution modes supported; previous pipeline '
             f'run has execution mode: '
             f'{latest_pipeline_view.pipeline.execution_mode}'))

    # Node ids that succeeded in the latest run; these are skipped in the new run.
    previously_succeeded_nodes = [
        node_id for node_id, node_state in
        latest_pipeline_view.get_node_states_dict().items()
        if node_state.is_success()
    ]
    # The partial run spans the whole pipeline (from = to = every node).
    # Which nodes actually execute is controlled by `skip_nodes`.
    pipeline_nodes = [
        node.node_info.id for node in pstate.get_all_pipeline_nodes(pipeline)
    ]
    # NOTE(review): `latest_pipeline_run_strategy` presumably makes skipped
    # nodes' outputs resolve from the latest previous run. Confirm this
    # against the SnapshotSettings proto documentation.
    latest_pipeline_snapshot_settings = pipeline_pb2.SnapshotSettings()
    latest_pipeline_snapshot_settings.latest_pipeline_run_strategy.SetInParent(
    )
    partial_run_option = pipeline_pb2.PartialRun(
        from_nodes=pipeline_nodes,
        to_nodes=pipeline_nodes,
        skip_nodes=previously_succeeded_nodes,
        snapshot_settings=latest_pipeline_snapshot_settings)

    return initiate_pipeline_start(mlmd_handle,
                                   pipeline,
                                   partial_run_option=partial_run_option)