def new(cls, mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline) -> 'PipelineState':
  """Creates a `PipelineState` object for a new pipeline.

  No active pipeline with the same pipeline uid should exist for the call to
  be successful.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline.

  Returns:
    A `PipelineState` object.

  Raises:
    status_lib.StatusNotOkError: If a pipeline with same UID already exists.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=orchestrator_context_name(pipeline_uid))

  # At most one orchestrator execution may be active for a pipeline uid.
  executions = mlmd_handle.store.get_executions_by_context(context.id)
  if any(execution_lib.is_execution_active(e) for e in executions):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.ALREADY_EXISTS,
        message=f'Pipeline with uid {pipeline_uid} already active.')

  # The full pipeline IR is stored b64-encoded as an execution property so
  # that it can be recovered from MLMD later.
  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties={
          _PIPELINE_IR:
              base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
      },
  )
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    # Sync pipelines carry a run id; record it on the execution for lookup.
    data_types_utils.set_metadata_value(
        execution.custom_properties[_PIPELINE_RUN_ID],
        pipeline.runtime_spec.pipeline_run_id.field_value.string_value)
  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  record_state_change_time()

  return cls(
      mlmd_handle=mlmd_handle, pipeline=pipeline, execution_id=execution.id)
def initiate_pipeline_start(
    mlmd_handle: metadata.Metadata,
    pipeline: pipeline_pb2.Pipeline) -> metadata_store_pb2.Execution:
  """Initiates a pipeline start operation.

  Upon success, MLMD is updated to signal that the given pipeline must be
  started.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline to start.

  Returns:
    The pipeline-level MLMD execution proto upon success.

  Raises:
    status_lib.StatusNotOkError: Failure to initiate pipeline start or if
      execution is not inactive after waiting `timeout_secs`.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=_orchestrator_context_name(pipeline_uid))

  # Refuse to start if an orchestrator execution is already active for this
  # pipeline uid.
  executions = mlmd_handle.store.get_executions_by_context(context.id)
  if any(execution_lib.is_execution_active(e) for e in executions):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.ALREADY_EXISTS,
        message=f'Pipeline with uid {pipeline_uid} already started.')

  # Store the b64-encoded pipeline IR so the orchestrator can recover it.
  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties={
          _PIPELINE_IR:
              base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
      })
  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  logging.info('Registered execution (id: %s) for the pipeline with uid: %s',
               execution.id, pipeline_uid)
  return execution
def _fix_deployment_config(
    input_pipeline: p_pb2.Pipeline,
    node_ids_to_keep: Collection[str]) -> Union[any_pb2.Any, None]:
  """Filter per-node deployment configs.

  Cast deployment configs from Any proto to IntermediateDeploymentConfig.
  Take all three per-node fields and filter out the nodes using
  node_ids_to_keep. This works because those fields don't contain references
  to other nodes.

  Args:
    input_pipeline: The input Pipeline IR proto.
    node_ids_to_keep: Set of node_ids to keep.

  Returns:
    If the deployment_config field is set in the input_pipeline, this would
    output the deployment config with filtered per-node configs, then cast
    into an Any proto. If the deployment_config field is unset in the
    input_pipeline, then this function would return None.
  """
  if not input_pipeline.HasField('deployment_config'):
    return None

  deployment_config = p_pb2.IntermediateDeploymentConfig()
  input_pipeline.deployment_config.Unpack(deployment_config)

  def _prune(config_map: MutableMapping[str, Any]):
    # Snapshot the keys up front: a proto map can't be mutated while it is
    # being iterated.
    for node_id in list(config_map):
      if node_id not in node_ids_to_keep:
        del config_map[node_id]

  _prune(deployment_config.executor_specs)
  _prune(deployment_config.custom_driver_specs)
  _prune(deployment_config.node_level_platform_configs)

  packed = any_pb2.Any()
  packed.Pack(deployment_config)
  return packed