def testPutParentContextIfNotExists(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    parent_context = context_lib.register_context_if_not_exists(
        metadata_handler=m,
        context_type_name='my_context_type',
        context_name='parent_context_name')
    child_context = context_lib.register_context_if_not_exists(
        metadata_handler=m,
        context_type_name='my_context_type',
        context_name='child_context_name')
    context_lib.put_parent_context_if_not_exists(
        m, parent_id=parent_context.id, child_id=child_context.id)
    # Duplicated call should succeed.
    context_lib.put_parent_context_if_not_exists(
        m, parent_id=parent_context.id, child_id=child_context.id)
def testRegisterContextByTypeAndName(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    context_lib.register_context_if_not_exists(
        metadata_handler=m,
        context_type_name='my_context_type',
        context_name='my_context')
    # Duplicated call should succeed.
    context = context_lib.register_context_if_not_exists(
        metadata_handler=m,
        context_type_name='my_context_type',
        context_name='my_context')

    self.assertProtoEquals(
        """
        id: 1
        name: 'my_context_type'
        """, m.store.get_context_type('my_context_type'))
    self.assertEqual(
        context,
        m.store.get_context_by_type_and_name('my_context_type', 'my_context'))
def new(cls, mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline) -> 'PipelineState':
  """Creates a `PipelineState` object for a new pipeline.

  No active pipeline with the same pipeline uid should exist for the call
  to be successful.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline.

  Returns:
    A `PipelineState` object.

  Raises:
    status_lib.StatusNotOkError: If a pipeline with same UID already exists.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=orchestrator_context_name(pipeline_uid))

  executions = mlmd_handle.store.get_executions_by_context(context.id)
  if any(e for e in executions if execution_lib.is_execution_active(e)):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.ALREADY_EXISTS,
        message=f'Pipeline with uid {pipeline_uid} already active.')

  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties={
          _PIPELINE_IR:
              base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
      },
  )
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    data_types_utils.set_metadata_value(
        execution.custom_properties[_PIPELINE_RUN_ID],
        pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  record_state_change_time()

  return cls(
      mlmd_handle=mlmd_handle, pipeline=pipeline, execution_id=execution.id)
def testGetCachedOutputArtifactsForNodesWithNoOutput(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    cache_context = context_lib.register_context_if_not_exists(
        m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
    cached_output = cache_utils.get_cached_outputs(m, cache_context)
    # No succeeded execution is associated with this context yet, so the
    # cached output is None.
    self.assertIsNone(cached_output)

    execution_one = execution_publish_utils.register_execution(
        m, metadata_store_pb2.ExecutionType(name='my_type'), [cache_context])
    execution_publish_utils.publish_succeeded_execution(m, execution_one.id,
                                                        [cache_context])
    cached_output = cache_utils.get_cached_outputs(m, cache_context)
    # A succeeded execution is associated with this context, so the cached
    # output is not None but an empty dict.
    self.assertIsNotNone(cached_output)
    self.assertEmpty(cached_output)
def initiate_pipeline_start(
    mlmd_handle: metadata.Metadata,
    pipeline: pipeline_pb2.Pipeline) -> metadata_store_pb2.Execution:
  """Initiates a pipeline start operation.

  Upon success, MLMD is updated to signal that the given pipeline must be
  started.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline to start.

  Returns:
    The pipeline-level MLMD execution proto upon success.

  Raises:
    status_lib.StatusNotOkError: Failure to initiate pipeline start.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=_orchestrator_context_name(pipeline_uid))

  executions = mlmd_handle.store.get_executions_by_context(context.id)
  if any(e for e in executions if execution_lib.is_execution_active(e)):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.ALREADY_EXISTS,
        message=f'Pipeline with uid {pipeline_uid} already started.')

  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties={
          _PIPELINE_IR:
              base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
      })
  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  logging.info('Registered execution (id: %s) for the pipeline with uid: %s',
               execution.id, pipeline_uid)
  return execution
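# A minimal usage sketch, not part of the source above: `connection_config`
# and `pipeline` (a pipeline_pb2.Pipeline IR) are assumed to exist, as in the
# tests in this section. A second call while the first pipeline execution is
# still active is expected to fail with ALREADY_EXISTS.
with metadata.Metadata(connection_config=connection_config) as m:
  execution = initiate_pipeline_start(m, pipeline)
  try:
    initiate_pipeline_start(m, pipeline)  # Same uid, still active.
  except status_lib.StatusNotOkError as e:
    assert e.code == status_lib.Code.ALREADY_EXISTS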
def _get_pipeline_run_context(
    self,
    run_id: str,
    register_if_not_found: bool = False) -> metadata_store_pb2.Context:
  """Gets the pipeline_run_context for a given pipeline run id.

  When called, it will first attempt to get the pipeline run context from
  the in-memory cache. If not found there, it will raise LookupError unless
  `register_if_not_found` is set to True. If `register_if_not_found` is set
  to True, this method will register the pipeline_run_context in MLMD, add
  it to the in-memory cache, and return the pipeline_run_context.

  Args:
    run_id: The pipeline_run_id whose Context to query.
    register_if_not_found: If set to True, it will register the
      pipeline_run_id in MLMD if the pipeline_run_id cannot be found in MLMD.
      If set to False, it will raise LookupError. Defaults to False.

  Returns:
    The requested pipeline run Context.

  Raises:
    LookupError: If `register_if_not_found` is not set to True, and the
      pipeline_run_id cannot be found in MLMD.
  """
  if run_id not in self._pipeline_run_contexts:
    if register_if_not_found:
      pipeline_run_context = context_lib.register_context_if_not_exists(
          self._mlmd,
          context_type_name=constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
          context_name=run_id)
      self._pipeline_run_contexts[run_id] = pipeline_run_context
    else:
      raise LookupError(f'pipeline_run_id {run_id} not found in MLMD.')
  return self._pipeline_run_contexts[run_id]
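# A hypothetical sketch of the lookup-vs-register contract described in the
# docstring above, written as if inside another method of the same class;
# 'run_2023' is an illustrative run id, not from the source.
try:
  ctx = self._get_pipeline_run_context('run_2023')  # Raises if never seen.
except LookupError:
  ctx = self._get_pipeline_run_context('run_2023', register_if_not_found=True)
# Later calls are served from the in-memory cache and return the same object.
assert self._get_pipeline_run_context('run_2023') is ctx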
def testGetCachedOutputArtifacts(self):
  # Output artifacts that will be used by the first execution with the same
  # cache key.
  output_model_one = standard_artifacts.Model()
  output_model_one.uri = 'model_one'
  output_model_two = standard_artifacts.Model()
  output_model_two.uri = 'model_two'
  output_example_one = standard_artifacts.Examples()
  output_example_one.uri = 'example_one'
  # Output artifacts that will be used by the second execution with the same
  # cache key.
  output_model_three = standard_artifacts.Model()
  output_model_three.uri = 'model_three'
  output_model_four = standard_artifacts.Model()
  output_model_four.uri = 'model_four'
  output_example_two = standard_artifacts.Examples()
  output_example_two.uri = 'example_two'
  output_models_key = 'output_models'
  output_examples_key = 'output_examples'
  with metadata.Metadata(connection_config=self._connection_config) as m:
    cache_context = context_lib.register_context_if_not_exists(
        m, context_lib.CONTEXT_TYPE_EXECUTION_CACHE, 'cache_key')
    execution_one = execution_publish_utils.register_execution(
        m, metadata_store_pb2.ExecutionType(name='my_type'), [cache_context])
    execution_publish_utils.publish_succeeded_execution(
        m,
        execution_one.id, [cache_context],
        output_artifacts={
            output_models_key: [output_model_one, output_model_two],
            output_examples_key: [output_example_one]
        })
    execution_two = execution_publish_utils.register_execution(
        m, metadata_store_pb2.ExecutionType(name='my_type'), [cache_context])
    execution_publish_utils.publish_succeeded_execution(
        m,
        execution_two.id, [cache_context],
        output_artifacts={
            output_models_key: [output_model_three, output_model_four],
            output_examples_key: [output_example_two]
        })

    # The cached outputs should be the artifacts produced by the most recent
    # execution under the given cache context.
    cached_output = cache_utils.get_cached_outputs(m, cache_context)
    self.assertLen(cached_output, 2)
    self.assertLen(cached_output[output_models_key], 2)
    self.assertLen(cached_output[output_examples_key], 1)
    self.assertProtoPartiallyEquals(
        cached_output[output_models_key][0].mlmd_artifact,
        output_model_three.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    self.assertProtoPartiallyEquals(
        cached_output[output_models_key][1].mlmd_artifact,
        output_model_four.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    self.assertProtoPartiallyEquals(
        cached_output[output_examples_key][0].mlmd_artifact,
        output_example_two.mlmd_artifact,
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
def get_cache_context(
    metadata_handler: metadata.Metadata,
    pipeline_node: pipeline_pb2.PipelineNode,
    pipeline_info: pipeline_pb2.PipelineInfo,
    executor_spec: Optional[message.Message] = None,
    input_artifacts: Optional[Mapping[Text, Sequence[types.Artifact]]] = None,
    output_artifacts: Optional[Mapping[Text, Sequence[types.Artifact]]] = None,
    parameters: Optional[Mapping[Text, Any]] = None
) -> metadata_store_pb2.Context:
  """Gets cache context for a potential node execution.

  The cache key is generated by applying the SHA-256 hash function to:
  - Serialized pipeline info.
  - Serialized node_info of the PipelineNode.
  - Serialized executor spec.
  - Serialized input artifacts, if any.
  - Serialized output artifacts, if any. The uri is cleared before hashing.
  - Serialized parameters, if any.
  - Serialized module file content, if a module file is present in parameters.

  Args:
    metadata_handler: A handler to access MLMD store.
    pipeline_node: A pipeline_pb2.PipelineNode instance to represent the node.
    pipeline_info: Information of the pipeline.
    executor_spec: A proto message representing the executor specification.
    input_artifacts: Input artifacts of the potential execution. The order of
      the artifacts under a key matters when calculating the cache key.
    output_artifacts: Output artifacts skeleton of the potential execution.
      The order of the artifacts under a key matters when calculating the
      cache key.
    parameters: Parameters of the potential execution.

  Returns:
    A metadata_store_pb2.Context for the cache key.
  """
  h = hashlib.sha256()
  h.update(pipeline_info.SerializeToString(deterministic=True))
  h.update(pipeline_node.node_info.SerializeToString(deterministic=True))
  if executor_spec:
    h.update(executor_spec.SerializeToString(deterministic=True))
  for key in sorted(input_artifacts or {}):
    h.update(key.encode())
    for artifact in input_artifacts[key]:
      h.update(artifact.mlmd_artifact.SerializeToString(deterministic=True))
  for key in sorted(output_artifacts or {}):
    h.update(key.encode())
    for artifact in output_artifacts[key]:
      stateless_artifact = copy.deepcopy(artifact)
      # Output uri and name should not be taken into consideration as part of
      # the cache key.
      stateless_artifact.uri = ''
      stateless_artifact.name = ''
      h.update(
          stateless_artifact.mlmd_artifact.SerializeToString(
              deterministic=True))
  parameters = parameters or {}
  for key, value in sorted(parameters.items()):
    h.update(key.encode())
    h.update(str(value).encode())
  # Special treatment for module files, as their content is used as part of
  # the processing logic. Currently this pattern is employed by Trainer and
  # Transform.
  if ('module_file' in parameters and parameters['module_file'] and
      fileio.exists(parameters['module_file'])):
    with fileio.open(parameters['module_file'], 'r') as f:
      h.update(f.read().encode())
  return context_lib.register_context_if_not_exists(
      metadata_handler=metadata_handler,
      context_type_name=context_lib.CONTEXT_TYPE_EXECUTION_CACHE,
      context_name=h.hexdigest())
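# A minimal sketch (assumed names: `connection_config`, `pipeline_node`,
# `pipeline_info`) of the determinism property the cache key relies on: the
# SHA-256 digest of identical, deterministically serialized inputs is
# identical, so repeated calls resolve to the same MLMD context rather than
# creating duplicates.
with metadata.Metadata(connection_config=connection_config) as m:
  ctx_a = get_cache_context(m, pipeline_node, pipeline_info)
  ctx_b = get_cache_context(m, pipeline_node, pipeline_info)
  assert ctx_a.id == ctx_b.id  # Same cache key, same reused context.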
def new(
    cls,
    mlmd_handle: metadata.Metadata,
    pipeline: pipeline_pb2.Pipeline,
    pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
) -> 'PipelineState':
  """Creates a `PipelineState` object for a new pipeline.

  No active pipeline with the same pipeline uid should exist for the call
  to be successful.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline.
    pipeline_run_metadata: Pipeline run metadata.

  Returns:
    A `PipelineState` object.

  Raises:
    status_lib.StatusNotOkError: If a pipeline with same UID already exists.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=orchestrator_context_name(pipeline_uid))

  executions = mlmd_handle.store.get_executions_by_context(context.id)
  if any(e for e in executions if execution_lib.is_execution_active(e)):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.ALREADY_EXISTS,
        message=f'Pipeline with uid {pipeline_uid} already active.')

  exec_properties = {_PIPELINE_IR: _base64_encode(pipeline)}
  if pipeline_run_metadata:
    exec_properties[_PIPELINE_RUN_METADATA] = json_utils.dumps(
        pipeline_run_metadata)

  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties=exec_properties)
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    data_types_utils.set_metadata_value(
        execution.custom_properties[_PIPELINE_RUN_ID],
        pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

  # Set the node state to COMPLETE for any nodes that are marked to be
  # skipped in a partial pipeline run.
  node_states_dict = {}
  for node in get_all_pipeline_nodes(pipeline):
    if node.execution_options.HasField('skip'):
      logging.info('Node %s is skipped in this partial run.',
                   node.node_info.id)
      node_states_dict[node.node_info.id] = NodeState(
          state=NodeState.COMPLETE)
  if node_states_dict:
    _save_node_states_dict(execution, node_states_dict)

  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  record_state_change_time()

  return cls(
      mlmd_handle=mlmd_handle, pipeline=pipeline, execution_id=execution.id)
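# A hypothetical usage sketch: `connection_config` and `pipeline` are assumed
# to exist, and the 'triggered_by' key is purely illustrative. As shown above,
# `pipeline_run_metadata` is stored on the orchestrator execution as a
# JSON-encoded exec property.
with metadata.Metadata(connection_config=connection_config) as m:
  pipeline_state = PipelineState.new(
      m, pipeline, pipeline_run_metadata={'triggered_by': 'scheduler'})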