def _wait_for_inactivation(
    mlmd_handle: metadata.Metadata,
    execution_id: int,
    timeout_secs: float = DEFAULT_WAIT_FOR_INACTIVATION_TIMEOUT_SECS
) -> None:
  """Waits for the given execution to become inactive.

  The execution is polled at least once, and once more when the deadline is
  reached, so an execution that is already inactive (or becomes inactive
  during the last sleep) is not reported as a timeout.

  Args:
    mlmd_handle: A handle to the MLMD db.
    execution_id: Id of the execution whose inactivation is awaited.
    timeout_secs: Amount of time in seconds to wait.

  Raises:
    StatusNotOkError: With error code `DEADLINE_EXCEEDED` if execution is not
      inactive after waiting approx. `timeout_secs`.
  """
  # Poll at most every 10 seconds, and at least ~4 times within the timeout.
  polling_interval_secs = min(10.0, timeout_secs / 4)
  end_time = time.time() + timeout_secs
  while True:
    updated_executions = mlmd_handle.store.get_executions_by_id(
        [execution_id])
    if not execution_lib.is_execution_active(updated_executions[0]):
      return
    remaining_secs = end_time - time.time()
    if remaining_secs <= 0:
      break
    # Never sleep past the deadline.
    time.sleep(min(polling_interval_secs, remaining_secs))
  raise status_lib.StatusNotOkError(
      code=status_lib.Code.DEADLINE_EXCEEDED,
      message=(f'Timed out ({timeout_secs} secs) waiting for execution '
               f'inactivation.'))
def initiate_pipeline_start(
    mlmd_handle: metadata.Metadata,
    pipeline: pipeline_pb2.Pipeline) -> pstate.PipelineState:
  """Initiates a pipeline start operation.

  Works on a deep copy of `pipeline`, validates it and then creates the
  pipeline state through `pstate.PipelineState.new`.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline to start.

  Returns:
    The `PipelineState` object upon success.

  Raises:
    status_lib.StatusNotOkError: Failure to initiate pipeline start. With code
      `INVALID_ARGUMENT` if it's a sync pipeline without `pipeline_run_id`
      provided.
  """
  pipeline = copy.deepcopy(pipeline)
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    run_id = pipeline.runtime_spec.pipeline_run_id
    # A sync pipeline needs a non-empty string run id.
    if not (run_id.HasField('field_value') and
            run_id.field_value.string_value):
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.INVALID_ARGUMENT,
          message='Sync pipeline IR must specify pipeline_run_id.')
  return pstate.PipelineState.new(mlmd_handle, pipeline)
def get_node_state(self, node_uid: task_lib.NodeUid) -> NodeState:
  """Returns the state recorded for the given node.

  Args:
    node_uid: Uid of the node whose state is requested.

  Returns:
    The `NodeState` stored for the node, or a default-constructed `NodeState`
    if none is stored.

  Raises:
    status_lib.StatusNotOkError: With code `INVALID_ARGUMENT` if the node does
      not belong to this pipeline.
  """
  self._check_context()
  if _is_node_uid_in_pipeline(node_uid, self.pipeline):
    states_by_node_id = _get_node_states_dict(self._execution)
    return states_by_node_id.get(node_uid.node_id, NodeState())
  raise status_lib.StatusNotOkError(
      code=status_lib.Code.INVALID_ARGUMENT,
      message=(f'Node {node_uid} does not belong to the pipeline '
               f'{self.pipeline_uid}'))
def orchestrate(mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
                service_job_manager: service_jobs.ServiceJobManager) -> None:
  """Performs a single iteration of the orchestration loop.

  Loads the pipeline states, sorts them into stop-initiated, update-initiated
  and active groups, and then orchestrates the groups in that order.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task_queue: A `TaskQueue` instance into which any tasks will be enqueued.
    service_job_manager: A `ServiceJobManager` instance for handling service
      jobs.

  Raises:
    status_lib.StatusNotOkError: If error generating tasks. With code
      `INTERNAL` if a pipeline state fits none of the three groups.
  """
  all_states = _get_pipeline_states(mlmd_handle)
  if not all_states:
    logging.info('No active pipelines to run.')
    return

  stop_initiated = []
  update_initiated = []
  active = []
  for state in all_states:
    with state:
      # Stop takes precedence over update, which takes precedence over active.
      if state.is_stop_initiated():
        group = stop_initiated
      elif state.is_update_initiated():
        group = update_initiated
      elif state.is_active():
        group = active
      else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.INTERNAL,
            message=(f'Found pipeline (uid: {state.pipeline_uid}) '
                     f'which is neither active nor stop-initiated.'))
      group.append(state)

  # Each phase: (pipeline states, log format, handler), processed in order.
  phases = (
      (stop_initiated, 'Orchestrating stop-initiated pipeline: %s',
       _orchestrate_stop_initiated_pipeline),
      (update_initiated, 'Orchestrating update-initiated pipeline: %s',
       _orchestrate_update_initiated_pipeline),
      (active, 'Orchestrating pipeline: %s', _orchestrate_active_pipeline),
  )
  for states, log_format, handler in phases:
    for state in states:
      logging.info(log_format, state.pipeline_uid)
      handler(mlmd_handle, task_queue, service_job_manager, state)
def _wrapper(*args, **kwargs):
  """Calls `fn`, logging any error and normalizing it to `StatusNotOkError`.

  A `StatusNotOkError` raised by `fn` is re-raised unchanged. Any other
  exception is wrapped in a `StatusNotOkError` with code `UNKNOWN`, chained to
  the original exception so its traceback is preserved as the cause.
  """
  try:
    return fn(*args, **kwargs)
  except Exception as e:  # pylint: disable=broad-except
    logging.exception('Error raised by `%s`:', fn.__name__)
    if isinstance(e, status_lib.StatusNotOkError):
      raise
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.UNKNOWN,
        message=f'`{fn.__name__}` error: {str(e)}') from e
def initiate_node_start(self, node_uid: task_lib.NodeUid) -> None:
  """Updates pipeline state to signal that a node should be started.

  Removes the node's stop-initiated marker, if present, together with its
  stored status code and message, and flags the state for commit.

  Args:
    node_uid: Uid of the node to start.

  Raises:
    status_lib.StatusNotOkError: With code `UNIMPLEMENTED` if the pipeline is
      not async; with code `INVALID_ARGUMENT` if the node does not belong to
      this pipeline.
  """
  if self.pipeline.execution_mode != pipeline_pb2.Pipeline.ASYNC:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.UNIMPLEMENTED,
        message='Node can be started only for async pipelines.')
  if not _is_node_uid_in_pipeline(node_uid, self.pipeline):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT,
        message=(
            f'Node given by uid {node_uid} does not belong to pipeline '
            f'given by uid {self.pipeline_uid}'))

  stop_marker = self.execution.custom_properties.pop(
      _node_stop_initiated_property(node_uid), None)
  if stop_marker is None:
    # No stop was initiated for this node; nothing to clear.
    return
  self.execution.custom_properties.pop(
      _node_status_code_property(node_uid), None)
  self.execution.custom_properties.pop(
      _node_status_msg_property(node_uid), None)
  self._commit = True
def _get_active_execution(
    pipeline_uid: task_lib.PipelineUid,
    executions: List[metadata_store_pb2.Execution]
) -> metadata_store_pb2.Execution:
  """Returns the single active execution among `executions`.

  Args:
    pipeline_uid: Uid of the pipeline; used only in error messages.
    executions: Candidate executions to search.

  Raises:
    status_lib.StatusNotOkError: With code `NOT_FOUND` if no execution is
      active; with code `INTERNAL` if more than one is active.
  """
  active = list(filter(execution_lib.is_execution_active, executions))
  num_active = len(active)
  if num_active == 1:
    return active[0]
  if num_active == 0:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.NOT_FOUND,
        message=f'No active pipeline with uid {pipeline_uid} to load state.')
  raise status_lib.StatusNotOkError(
      code=status_lib.Code.INTERNAL,
      message=(f'Expected 1 but found {num_active} active pipeline '
               f'executions for pipeline uid: {pipeline_uid}'))
def load(cls,
         mlmd_handle: metadata.Metadata,
         pipeline_uid: task_lib.PipelineUid,
         pipeline_run_id: Optional[str] = None) -> 'PipelineView':
  """Loads pipeline view from MLMD.

  When `pipeline_run_id` is not given and executions exist, the view is built
  from the execution chosen by `_get_latest_execution`. Otherwise the first
  execution whose run id custom property equals `pipeline_run_id` is used.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline_uid: Uid of the pipeline state to load.
    pipeline_run_id: Run id of the pipeline for the synchronous pipeline.

  Returns:
    A `PipelineView` object.

  Raises:
    status_lib.StatusNotOkError: With code=NOT_FOUND if no pipeline with the
      given pipeline uid exists in MLMD, or if no execution matches the
      requested run id.
  """
  context = mlmd_handle.store.get_context_by_type_and_name(
      type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=orchestrator_context_name(pipeline_uid))
  if not context:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.NOT_FOUND,
        message=f'No pipeline with uid {pipeline_uid} found.')

  executions = mlmd_handle.store.get_executions_by_context(context.id)
  if pipeline_run_id is None and executions:
    return cls(pipeline_uid, context, _get_latest_execution(executions))

  matching_execution = next(
      (e for e in executions
       if e.custom_properties[_PIPELINE_RUN_ID].string_value ==
       pipeline_run_id), None)
  if matching_execution is None:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.NOT_FOUND,
        message=f'No pipeline with run_id {pipeline_run_id} found.')
  return cls(pipeline_uid, context, matching_execution)
def load_from_orchestrator_context(
    cls, mlmd_handle: metadata.Metadata,
    context: metadata_store_pb2.Context) -> 'PipelineState':
  """Loads pipeline state for active pipeline under given orchestrator context.

  Args:
    mlmd_handle: A handle to the MLMD db.
    context: Pipeline context under which to find the pipeline execution.

  Returns:
    A `PipelineState` object built from the single active execution.

  Raises:
    status_lib.StatusNotOkError: With code=NOT_FOUND if no active pipeline
      exists for the given context in MLMD. With code=INTERNAL if more than 1
      active execution exists for given pipeline uid.
  """
  pipeline_uid = pipeline_uid_from_orchestrator_context(context)
  executions_in_context = mlmd_handle.store.get_executions_by_context(
      context.id)
  active = list(
      filter(execution_lib.is_execution_active, executions_in_context))
  num_active = len(active)
  if num_active == 0:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.NOT_FOUND,
        message=f'No active pipeline with uid {pipeline_uid} to load state.')
  if num_active > 1:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.INTERNAL,
        message=(f'Expected 1 but found {num_active} active pipeline '
                 f'executions for pipeline uid: {pipeline_uid}'))
  (active_execution,) = active
  return cls(
      mlmd_handle=mlmd_handle,
      pipeline_uid=pipeline_uid,
      context=context,
      execution=active_execution,
      commit=False)
def new(cls, mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline) -> 'PipelineState':
  """Creates a `PipelineState` object for a new pipeline.

  Registers the orchestrator context if needed, refuses to proceed when an
  active execution already exists under it, and otherwise stores a new
  execution holding the base64-encoded pipeline IR.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline.

  Returns:
    A `PipelineState` object.

  Raises:
    status_lib.StatusNotOkError: With code `ALREADY_EXISTS` if a pipeline with
      the same UID is already active.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=orchestrator_context_name(pipeline_uid))

  for existing in mlmd_handle.store.get_executions_by_context(context.id):
    if execution_lib.is_execution_active(existing):
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.ALREADY_EXISTS,
          message=f'Pipeline with uid {pipeline_uid} already active.')

  encoded_ir = base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties={_PIPELINE_IR: encoded_ir},
  )
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    # Sync pipelines additionally record their run id on the execution.
    run_id = pipeline.runtime_spec.pipeline_run_id.field_value.string_value
    data_types_utils.set_metadata_value(
        execution.custom_properties[_PIPELINE_RUN_ID], run_id)

  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  record_state_change_time()
  return cls(
      mlmd_handle=mlmd_handle, pipeline=pipeline, execution_id=execution.id)
def apply_pipeline_update(self) -> None:
  """Applies pipeline update that was previously initiated.

  Moves the pending pipeline IR from the execution's custom properties into
  its pipeline IR property, removes the pending IR and update options, and
  refreshes `self.pipeline` from the new IR.

  Raises:
    status_lib.StatusNotOkError: With code `INVALID_ARGUMENT` if there is no
      updated pipeline IR to apply.
  """
  self._check_context()
  pending_ir = _get_metadata_value(
      self._execution.custom_properties.get(_UPDATED_PIPELINE_IR))
  if not pending_ir:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT,
        message='No updated pipeline IR to apply')

  data_types_utils.set_metadata_value(
      self._execution.properties[_PIPELINE_IR], pending_ir)
  # The update is now applied; drop the bookkeeping that marked it pending.
  for applied_key in (_UPDATED_PIPELINE_IR, _UPDATE_OPTIONS):
    del self._execution.custom_properties[applied_key]
  self.pipeline = _base64_decode_pipeline(pending_ir)
def initiate_node_stop(self, node_uid: task_lib.NodeUid,
                       status: status_lib.Status) -> None:
  """Updates pipeline state to signal that a node should be stopped.

  Records a stop-initiated marker and the status code for the node in the
  execution's custom properties; the status message is recorded only when it
  is non-empty.

  Args:
    node_uid: Uid of the node to stop.
    status: Status to associate with the stop.

  Raises:
    status_lib.StatusNotOkError: With code `UNIMPLEMENTED` if the pipeline is
      not async; with code `INVALID_ARGUMENT` if the node does not belong to
      this pipeline.
  """
  self._check_context()
  if self.pipeline.execution_mode != pipeline_pb2.Pipeline.ASYNC:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.UNIMPLEMENTED,
        message='Node can be stopped only for async pipelines.')
  if not _is_node_uid_in_pipeline(node_uid, self.pipeline):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT,
        message=(
            f'Node given by uid {node_uid} does not belong to pipeline '
            f'given by uid {self.pipeline_uid}'))

  custom_props = self._execution.custom_properties
  set_value = data_types_utils.set_metadata_value
  set_value(custom_props[_node_stop_initiated_property(node_uid)], 1)
  set_value(custom_props[_node_status_code_property(node_uid)],
            int(status.code))
  if status.message:
    set_value(custom_props[_node_status_msg_property(node_uid)],
              status.message)
def node_state_update_context(
    self, node_uid: task_lib.NodeUid) -> Iterator[NodeState]:
  """Context manager for updating the node state.

  Yields the node's `NodeState` (creating a default one if absent) for the
  caller to modify, then writes the node states back to the execution. A
  state transition is logged when the state changed while yielded.

  Args:
    node_uid: Uid of the node whose state is to be updated.

  Yields:
    The mutable `NodeState` of the node.

  Raises:
    status_lib.StatusNotOkError: With code `INVALID_ARGUMENT` if the node does
      not belong to this pipeline.
  """
  self._check_context()
  if not _is_node_uid_in_pipeline(node_uid, self.pipeline):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.INVALID_ARGUMENT,
        message=(f'Node {node_uid} does not belong to the pipeline '
                 f'{self.pipeline_uid}'))

  states_by_node_id = _get_node_states_dict(self._execution)
  node_state = states_by_node_id.setdefault(node_uid.node_id, NodeState())
  state_before = node_state.state
  yield node_state
  state_after = node_state.state
  if state_before != state_after:
    logging.info('Changing node state: %s -> %s; node uid: %s', state_before,
                 state_after, node_uid)
  _save_node_states_dict(self._execution, states_by_node_id)
def stop_node(mlmd_handle: metadata.Metadata,
              node_uid: task_lib.NodeUid,
              timeout_secs: Optional[float] = None) -> None:
  """Stops a node.

  Initiates a node stop operation and waits for the node execution to become
  inactive.

  Args:
    mlmd_handle: A handle to the MLMD db.
    node_uid: Uid of the node to be stopped.
    timeout_secs: Amount of time in seconds to wait for node to stop. If `None`,
      waits indefinitely.

  Raises:
    status_lib.StatusNotOkError: Failure to stop the node. With code `INTERNAL`
      if the node cannot be uniquely found in the pipeline.
  """
  logging.info('Received request to stop node; node uid: %s', node_uid)
  with _PIPELINE_OPS_LOCK:
    with pstate.PipelineState.load(
        mlmd_handle, node_uid.pipeline_uid) as pipeline_state:
      matching_nodes = [
          n for n in pstate.get_all_pipeline_nodes(pipeline_state.pipeline)
          if n.node_info.id == node_uid.node_id
      ]
      if len(matching_nodes) != 1:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.INTERNAL,
            message=(
                f'`stop_node` operation failed, unable to find node to stop: '
                f'{node_uid}'))
      with pipeline_state.node_state_update_context(node_uid) as node_state:
        # Only nodes that report themselves stoppable are moved to STOPPING.
        if node_state.is_stoppable():
          node_state.update(
              pstate.NodeState.STOPPING,
              status_lib.Status(
                  code=status_lib.Code.CANCELLED,
                  message='Cancellation requested by client.'))

  # Wait until the node is stopped or time out.
  _wait_for_node_inactivation(
      pipeline_state, node_uid, timeout_secs=timeout_secs)
def _wait_for_predicate(predicate_fn: Callable[[], bool], waiting_for_desc: str,
                        timeout_secs: Optional[float]) -> None:
  """Waits for `predicate_fn` to return `True` or until timeout seconds elapse.

  The predicate is evaluated at least once, and once more when the deadline is
  reached, so a condition that already holds (or starts holding during the
  last sleep) is never reported as a timeout.

  Args:
    predicate_fn: Zero-argument callable polled until it returns a truthy
      value.
    waiting_for_desc: Description of the awaited condition, used in the
      timeout error message.
    timeout_secs: Amount of time in seconds to wait. If `None`, waits
      indefinitely.

  Raises:
    status_lib.StatusNotOkError: With code `DEADLINE_EXCEEDED` if the predicate
      is still false after waiting approx. `timeout_secs`.
  """
  if timeout_secs is None:
    while not predicate_fn():
      time.sleep(_POLLING_INTERVAL_SECS)
    return

  end_time = time.time() + timeout_secs
  while True:
    if predicate_fn():
      return
    remaining_secs = end_time - time.time()
    if remaining_secs <= 0:
      break
    # Poll at least ~4 times within the timeout and never sleep past the
    # deadline.
    time.sleep(min(_POLLING_INTERVAL_SECS, timeout_secs / 4, remaining_secs))
  raise status_lib.StatusNotOkError(
      code=status_lib.Code.DEADLINE_EXCEEDED,
      message=(
          f'Timed out ({timeout_secs} secs) waiting for {waiting_for_desc}.'))
def load(cls, mlmd_handle: metadata.Metadata,
         pipeline_uid: task_lib.PipelineUid) -> 'PipelineState':
  """Loads pipeline state from MLMD.

  Looks up the orchestrator context for `pipeline_uid` and delegates to
  `load_from_orchestrator_context`.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline_uid: Uid of the pipeline state to load.

  Returns:
    A `PipelineState` object.

  Raises:
    status_lib.StatusNotOkError: With code=NOT_FOUND if no context exists for
      the given pipeline uid. Errors raised by
      `load_from_orchestrator_context` propagate unchanged.
  """
  orchestrator_context = mlmd_handle.store.get_context_by_type_and_name(
      type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=orchestrator_context_name(pipeline_uid))
  if orchestrator_context:
    return cls.load_from_orchestrator_context(mlmd_handle,
                                              orchestrator_context)
  raise status_lib.StatusNotOkError(
      code=status_lib.Code.NOT_FOUND,
      message=f'No pipeline with uid {pipeline_uid} found.')
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates active pipeline.

  Steps, in order:
    1. Marks the pipeline execution RUNNING and, for a SYNC pipeline that has
       exceeded its deadline, initiates a pipeline stop and returns.
    2. Handles nodes in state STOPPING (stops services or enqueues
       cancellation) and moves the ones that are done to STOPPED.
    3. Generates tasks with the SYNC or ASYNC task generator.
    4. Applies node state updates, promotes remaining STARTING nodes to
       STARTED, enqueues exec-node tasks and handles pipeline finalization.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task_queue: Task queue into which exec-node tasks are enqueued.
    service_job_manager: Manager used to query and stop node services.
    pipeline_state: State of the pipeline to orchestrate.

  Raises:
    status_lib.StatusNotOkError: With code `FAILED_PRECONDITION` if the
      pipeline execution mode is neither SYNC nor ASYNC.
  """
  pipeline = pipeline_state.pipeline
  # NOTE(review): entering `pipeline_state` presumably scopes state mutations
  # and commits them on exit -- confirm against `PipelineState`.
  with pipeline_state:
    assert pipeline_state.is_active()
    if pipeline_state.get_pipeline_execution_state() != (
        metadata_store_pb2.Execution.RUNNING):
      pipeline_state.set_pipeline_execution_state(
          metadata_store_pb2.Execution.RUNNING)
    orchestration_options = pipeline_state.get_orchestration_options()
    logging.info('Orchestration options: %s', orchestration_options)
    deadline_secs = orchestration_options.deadline_secs
    # The deadline applies only to SYNC pipelines and only when positive; it
    # is measured from the pipeline creation time.
    if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC and
        deadline_secs > 0 and
        time.time() - pipeline_state.pipeline_creation_time_secs_since_epoch() >
        deadline_secs):
      logging.error(
          'Aborting pipeline due to exceeding deadline (%s secs); '
          'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
      pipeline_state.initiate_stop(
          status_lib.Status(
              code=status_lib.Code.DEADLINE_EXCEEDED,
              message=('Pipeline aborted due to exceeding deadline '
                       f'({deadline_secs} secs)')))
      return

  def _filter_by_state(node_infos: List[_NodeInfo],
                       state_str: str) -> List[_NodeInfo]:
    # Keeps only the node infos whose node state equals `state_str`.
    return [n for n in node_infos if n.state.state == state_str]

  node_infos = _get_node_infos(pipeline_state)
  stopping_node_infos = _filter_by_state(node_infos, pstate.NodeState.STOPPING)

  # Tracks nodes stopped in the current iteration.
  stopped_node_infos: List[_NodeInfo] = []

  # Create cancellation tasks for nodes in state STOPPING.
  for node_info in stopping_node_infos:
    if service_job_manager.is_pure_service_node(
        pipeline_state, node_info.node.node_info.id):
      # Pure service node: stopped once its services report as stopped.
      if service_job_manager.stop_node_services(
          pipeline_state, node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                          node_info.node, task_queue):
      # NOTE(review): a truthy result presumably means cancellation is still
      # in flight, so the node stays in STOPPING -- confirm against helper.
      pass
    elif service_job_manager.is_mixed_service_node(
        pipeline_state, node_info.node.node_info.id):
      if service_job_manager.stop_node_services(
          pipeline_state, node_info.node.node_info.id):
        stopped_node_infos.append(node_info)
    else:
      # Nothing left to cancel or stop for this node.
      stopped_node_infos.append(node_info)

  # Change the state of stopped nodes from STOPPING to STOPPED.
  if stopped_node_infos:
    with pipeline_state:
      for node_info in stopped_node_infos:
        node_uid = task_lib.NodeUid.from_pipeline_node(
            pipeline, node_info.node)
        with pipeline_state.node_state_update_context(
            node_uid) as node_state:
          # Keep the status recorded on the node while changing its state.
          node_state.update(pstate.NodeState.STOPPED, node_state.status)

  # Initialize task generator for the pipeline.
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle,
        task_queue.contains_task_id,
        service_job_manager,
        fail_fast=orchestration_options.fail_fast)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, task_queue.contains_task_id, service_job_manager)
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'
        ))

  tasks = generator.generate(pipeline_state)

  with pipeline_state:
    # Handle all the UpdateNodeStateTasks by updating node states.
    for task in tasks:
      if task_lib.is_update_node_state_task(task):
        task = typing.cast(task_lib.UpdateNodeStateTask, task)
        with pipeline_state.node_state_update_context(
            task.node_uid) as node_state:
          node_state.update(task.state, task.status)

    # The update tasks are now applied; only the remaining tasks are
    # processed below.
    tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

    # If there are still nodes in state STARTING, change them to STARTED.
    for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
      node_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline_state.pipeline, node)
      with pipeline_state.node_state_update_context(
          node_uid) as node_state:
        if node_state.state == pstate.NodeState.STARTING:
          node_state.update(pstate.NodeState.STARTED)

    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task = typing.cast(task_lib.ExecNodeTask, task)
        task_queue.enqueue(task)
      else:
        # The only other task kind accepted here is a single pipeline
        # finalization task, and only for SYNC pipelines.
        assert task_lib.is_finalize_pipeline_task(task)
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
        assert len(tasks) == 1
        task = typing.cast(task_lib.FinalizePipelineTask, task)
        if task.status.code == status_lib.Code.OK:
          logging.info('Pipeline run successful; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        else:
          logging.info('Pipeline run failed; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        pipeline_state.initiate_stop(task.status)
def _orchestrate_active_pipeline(
    mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
    service_job_manager: service_jobs.ServiceJobManager,
    pipeline_state: pstate.PipelineState) -> None:
  """Orchestrates active pipeline.

  Marks the pipeline execution RUNNING in MLMD if it is not already, handles
  stop-initiated nodes for ASYNC pipelines, generates tasks with the
  mode-specific task generator and then processes the generated tasks.

  Args:
    mlmd_handle: A handle to the MLMD db.
    task_queue: Task queue into which exec-node tasks are enqueued.
    service_job_manager: Manager used to query and stop node services.
    pipeline_state: State of the pipeline to orchestrate.

  Raises:
    status_lib.StatusNotOkError: With code `FAILED_PRECONDITION` if the
      pipeline execution mode is neither SYNC nor ASYNC.
  """
  pipeline = pipeline_state.pipeline
  execution = pipeline_state.execution
  assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                        metadata_store_pb2.Execution.RUNNING)
  if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
    # Write a modified copy to MLMD; the `execution` object read from
    # `pipeline_state` is not mutated here.
    updated_execution = copy.deepcopy(execution)
    updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
    mlmd_handle.store.put_executions([updated_execution])

  # Initialize task generator for the pipeline.
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
        mlmd_handle, pipeline_state, task_queue.contains_task_id,
        service_job_manager)
  elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
    # Create cancellation tasks for stop-initiated nodes if necessary.
    stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
    for node in stop_initiated_nodes:
      if service_job_manager.is_pure_service_node(
          pipeline_state, node.node_info.id):
        service_job_manager.stop_node_services(pipeline_state,
                                               node.node_info.id)
      elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                            task_queue):
        # NOTE(review): a truthy result presumably means a cancellation task
        # was enqueued, so nothing else is done for this node -- confirm.
        pass
      elif service_job_manager.is_mixed_service_node(
          pipeline_state, node.node_info.id):
        service_job_manager.stop_node_services(pipeline_state,
                                               node.node_info.id)
    # The ids of the stop-initiated nodes are passed to the async generator.
    generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
        mlmd_handle, pipeline_state, task_queue.contains_task_id,
        service_job_manager,
        set(n.node_info.id for n in stop_initiated_nodes))
  else:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC and ASYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'
        ))

  tasks = generator.generate()

  with pipeline_state:
    for task in tasks:
      if task_lib.is_exec_node_task(task):
        task = typing.cast(task_lib.ExecNodeTask, task)
        task_queue.enqueue(task)
      elif task_lib.is_finalize_node_task(task):
        # Node finalization is only accepted for ASYNC pipelines and is
        # handled by initiating a node stop with the task's status.
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
        task = typing.cast(task_lib.FinalizeNodeTask, task)
        pipeline_state.initiate_node_stop(task.node_uid, task.status)
      else:
        # The only other task kind accepted here is a single pipeline
        # finalization task, and only for SYNC pipelines.
        assert task_lib.is_finalize_pipeline_task(task)
        assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
        assert len(tasks) == 1
        task = typing.cast(task_lib.FinalizePipelineTask, task)
        if task.status.code == status_lib.Code.OK:
          logging.info('Pipeline run successful; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        else:
          logging.info('Pipeline run failed; pipeline uid: %s',
                       pipeline_state.pipeline_uid)
        pipeline_state.initiate_stop(task.status)
def fn2():
  """Always raises `StatusNotOkError` with code `ALREADY_EXISTS`."""
  error = status_lib.StatusNotOkError(
      code=status_lib.Code.ALREADY_EXISTS, message='test error 2')
  raise error
def new(
    cls,
    mlmd_handle: metadata.Metadata,
    pipeline: pipeline_pb2.Pipeline,
    pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
) -> 'PipelineState':
  """Creates a `PipelineState` object for a new pipeline.

  Registers the orchestrator context if needed, refuses to proceed when an
  active execution already exists under it, and otherwise stores a new
  execution holding the encoded pipeline IR (plus the JSON-serialized run
  metadata when given).

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline.
    pipeline_run_metadata: Pipeline run metadata.

  Returns:
    A `PipelineState` object.

  Raises:
    status_lib.StatusNotOkError: With code `ALREADY_EXISTS` if a pipeline with
      the same UID is already active.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=orchestrator_context_name(pipeline_uid))

  for existing in mlmd_handle.store.get_executions_by_context(context.id):
    if execution_lib.is_execution_active(existing):
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.ALREADY_EXISTS,
          message=f'Pipeline with uid {pipeline_uid} already active.')

  exec_properties = {_PIPELINE_IR: _base64_encode(pipeline)}
  if pipeline_run_metadata:
    serialized_metadata = json_utils.dumps(pipeline_run_metadata)
    exec_properties[_PIPELINE_RUN_METADATA] = serialized_metadata

  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties=exec_properties)

  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    run_id = pipeline.runtime_spec.pipeline_run_id.field_value.string_value
    data_types_utils.set_metadata_value(
        execution.custom_properties[_PIPELINE_RUN_ID], run_id)

    # Set the node state to COMPLETE for any nodes that are marked to be
    # skipped in a partial pipeline run.
    skipped_node_states = {}
    for node in get_all_pipeline_nodes(pipeline):
      if not node.execution_options.HasField('skip'):
        continue
      logging.info('Node %s is skipped in this partial run.',
                   node.node_info.id)
      skipped_node_states[node.node_info.id] = NodeState(
          state=NodeState.COMPLETE)
    if skipped_node_states:
      _save_node_states_dict(execution, skipped_node_states)

  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  record_state_change_time()
  return cls(
      mlmd_handle=mlmd_handle, pipeline=pipeline, execution_id=execution.id)
def resume_pipeline(mlmd_handle: metadata.Metadata,
                    pipeline: pipeline_pb2.Pipeline) -> pstate.PipelineState:
  """Resumes a pipeline run from previously failed nodes.

  Finds the most recently created run of the pipeline and starts a new partial
  run that skips the nodes which succeeded in that run.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline to resume.

  Returns:
    The `PipelineState` object upon success.

  Raises:
    status_lib.StatusNotOkError: Failure to resume pipeline. With code
      `ALREADY_EXISTS` if a pipeline is already running. With code `NOT_FOUND`
      if a previous pipeline run is not found for resuming. With code
      `FAILED_PRECONDITION` if the given pipeline or the previous pipeline run
      is not a SYNC pipeline.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  logging.info('Received request to resume pipeline; pipeline uid: %s',
               pipeline_uid)
  if pipeline.execution_mode != pipeline_pb2.Pipeline.SYNC:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC pipeline execution modes supported; '
            f'found pipeline with execution mode: {pipeline.execution_mode}'
        ))

  # Pick the most recently created run, rejecting the request outright if any
  # run of this pipeline is still active.
  latest_pipeline_view = None
  for view in pstate.PipelineView.load_all(mlmd_handle, pipeline_uid):
    execution = view.execution
    if execution_lib.is_execution_active(execution):
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.ALREADY_EXISTS,
          message=(
              f'Can not resume pipeline. An active pipeline is already '
              f'running with uid {pipeline_uid}.'))
    if (not latest_pipeline_view or execution.create_time_since_epoch >
        latest_pipeline_view.execution.create_time_since_epoch):
      latest_pipeline_view = view

  if not latest_pipeline_view:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.NOT_FOUND,
        message='Pipeline failed to resume. No previous pipeline run found.')
  if latest_pipeline_view.pipeline.execution_mode != pipeline_pb2.Pipeline.SYNC:
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.FAILED_PRECONDITION,
        message=(
            f'Only SYNC pipeline execution modes supported; previous pipeline '
            f'run has execution mode: '
            f'{latest_pipeline_view.pipeline.execution_mode}'))

  # Get succeeded nodes in latest pipeline run.
  latest_pipeline_node_states = latest_pipeline_view.get_node_states_dict()
  previously_succeeded_nodes = [
      node for node, node_state in latest_pipeline_node_states.items()
      if node_state.is_success()
  ]
  pipeline_nodes = [
      node.node_info.id for node in pstate.get_all_pipeline_nodes(pipeline)
  ]

  latest_pipeline_snapshot_settings = pipeline_pb2.SnapshotSettings()
  latest_pipeline_snapshot_settings.latest_pipeline_run_strategy.SetInParent()
  # The partial run spans every node of the pipeline but skips the ones that
  # already succeeded.
  partial_run_option = pipeline_pb2.PartialRun(
      from_nodes=pipeline_nodes,
      to_nodes=pipeline_nodes,
      skip_nodes=previously_succeeded_nodes,
      snapshot_settings=latest_pipeline_snapshot_settings)

  return initiate_pipeline_start(
      mlmd_handle, pipeline, partial_run_option=partial_run_option)