def publish_internal_execution(
    metadata_handler: metadata.Metadata,
    contexts: Sequence[metadata_store_pb2.Context],
    execution_id: int,
    output_artifacts: Optional[MutableMapping[str,
                                              Sequence[types.Artifact]]] = None
) -> None:
  """Marks an existing execution as a success and links its outputs through INTERNAL_OUTPUT events.

  Args:
    metadata_handler: A handler to access MLMD.
    contexts: MLMD contexts to associate with the execution.
    execution_id: The id of the execution.
    output_artifacts: Output artifacts of the execution. Each artifact will be
      linked with the execution through an event with type INTERNAL_OUTPUT.
  """
  [execution] = metadata_handler.store.get_executions_by_id([execution_id])
  execution.last_known_state = metadata_store_pb2.Execution.COMPLETE

  execution_lib.put_execution(
      metadata_handler,
      execution,
      contexts,
      output_artifacts=output_artifacts,
      output_event_type=metadata_store_pb2.Event.INTERNAL_OUTPUT)
def publish_cached_execution(
    metadata_handler: metadata.Metadata,
    contexts: Sequence[metadata_store_pb2.Context],
    execution_id: int,
    output_artifacts: Optional[MutableMapping[str,
                                              Sequence[types.Artifact]]] = None,
) -> None:
  """Marks an existing execution as using cached outputs from a previous execution.

  Args:
    metadata_handler: A handler to access MLMD.
    contexts: MLMD contexts to associate with the execution.
    execution_id: The id of the execution.
    output_artifacts: Output artifacts of the execution. Each artifact will be
      linked with the execution through an event with type OUTPUT.
  """
  [execution] = metadata_handler.store.get_executions_by_id([execution_id])
  execution.last_known_state = metadata_store_pb2.Execution.CACHED

  execution_lib.put_execution(
      metadata_handler,
      execution,
      contexts,
      input_artifacts=None,
      output_artifacts=output_artifacts)
def publish_failed_execution(metadata_handler: metadata.Metadata,
                             contexts: Sequence[metadata_store_pb2.Context],
                             execution_id: int) -> None:
  """Marks an existing execution as failed.

  Args:
    metadata_handler: A handler to access MLMD.
    contexts: MLMD contexts to associate with the execution.
    execution_id: The id of the execution.
  """
  [execution] = metadata_handler.store.get_executions_by_id([execution_id])
  execution.last_known_state = metadata_store_pb2.Execution.FAILED

  execution_lib.put_execution(metadata_handler, execution, contexts)
def register_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    contexts: Sequence[metadata_store_pb2.Context],
    input_artifacts: Optional[MutableMapping[str,
                                             Sequence[types.Artifact]]] = None,
    exec_properties: Optional[Mapping[str, types.Property]] = None,
) -> metadata_store_pb2.Execution:
  """Registers a new execution in MLMD.

  Along with the execution:
  - the input artifacts will be linked to the execution.
  - the contexts will be linked to both the execution and its input artifacts.

  Args:
    metadata_handler: A handler to access MLMD.
    execution_type: The type of the execution.
    contexts: MLMD contexts to associate with the execution.
    input_artifacts: Input artifacts of the execution. Each artifact will be
      linked with the execution through an event.
    exec_properties: Execution properties. Will be attached to the execution.

  Returns:
    An MLMD execution that is registered in MLMD, with id populated.
  """
  execution = execution_lib.prepare_execution(
      metadata_handler, execution_type, metadata_store_pb2.Execution.RUNNING,
      exec_properties)
  return execution_lib.put_execution(
      metadata_handler, execution, contexts, input_artifacts=input_artifacts)
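# A minimal usage sketch (illustrative, not library code) tying together the
# helpers above. `run_node_fn` and `output_artifacts` are hypothetical
# placeholders for the node-running callable and the system-generated outputs.
def _run_and_publish_sketch(metadata_handler, contexts, run_node_fn,
                            output_artifacts):
  """Registers an execution, runs a node, then publishes success or failure."""
  execution = register_execution(
      metadata_handler=metadata_handler,
      execution_type=metadata_store_pb2.ExecutionType(name='my_node_type'),
      contexts=contexts)
  try:
    run_node_fn()
  except Exception:
    # On failure, only the execution's last_known_state flips to FAILED.
    publish_failed_execution(metadata_handler, contexts, execution.id)
    raise
  # On success, outputs are linked to the execution via INTERNAL_OUTPUT events.
  publish_internal_execution(
      metadata_handler, contexts, execution.id,
      output_artifacts=output_artifacts)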
def testPutExecutionGraph(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    # Prepares an input artifact. The artifact should be registered in MLMD
    # before the put_execution call.
    input_example = standard_artifacts.Examples()
    input_example.uri = 'example'
    input_example.type_id = common_utils.register_type_if_not_exist(
        m, input_example.artifact_type).id
    [input_example.id] = m.store.put_artifacts([input_example.mlmd_artifact])
    # Prepares an output artifact.
    output_model = standard_artifacts.Model()
    output_model.uri = 'model'
    execution = execution_lib.prepare_execution(
        m,
        metadata_store_pb2.ExecutionType(name='my_execution_type'),
        exec_properties={
            'p1': 1,
            'p2': '2'
        },
        state=metadata_store_pb2.Execution.COMPLETE)
    contexts = self._generate_contexts(m)
    execution = execution_lib.put_execution(
        m,
        execution,
        contexts,
        input_artifacts={'example': [input_example]},
        output_artifacts={'model': [output_model]})
    self.assertProtoPartiallyEquals(
        output_model.mlmd_artifact,
        m.store.get_artifacts_by_id([output_model.id])[0],
        ignored_fields=[
            'create_time_since_epoch', 'last_update_time_since_epoch'
        ])
    # Verifies edges between artifacts and execution.
    [input_event] = m.store.get_events_by_artifact_ids([input_example.id])
    self.assertEqual(input_event.execution_id, execution.id)
    self.assertEqual(input_event.type, metadata_store_pb2.Event.INPUT)
    [output_event] = m.store.get_events_by_artifact_ids([output_model.id])
    self.assertEqual(output_event.execution_id, execution.id)
    self.assertEqual(output_event.type, metadata_store_pb2.Event.OUTPUT)
    # Verifies edges connecting contexts and {artifacts, execution}.
    context_ids = [context.id for context in contexts]
    self.assertCountEqual(
        [c.id for c in m.store.get_contexts_by_artifact(input_example.id)],
        context_ids)
    self.assertCountEqual(
        [c.id for c in m.store.get_contexts_by_artifact(output_model.id)],
        context_ids)
    self.assertCountEqual(
        [c.id for c in m.store.get_contexts_by_execution(execution.id)],
        context_ids)
def commit(self) -> None:
  """Commits pipeline state to MLMD if there are any mutations."""
  if self._commit:
    self.execution = execution_lib.put_execution(self.mlmd_handle,
                                                 self.execution,
                                                 [self.context])
    logging.info('Committed execution (id: %s) for pipeline with uid: %s',
                 self.execution.id, self.pipeline_uid)
  self._commit = False
def publish_failed_execution(
    metadata_handler: metadata.Metadata,
    contexts: Sequence[metadata_store_pb2.Context],
    execution_id: int,
    executor_output: Optional[execution_result_pb2.ExecutorOutput] = None
) -> None:
  """Marks an existing execution as failed.

  Args:
    metadata_handler: A handler to access MLMD.
    contexts: MLMD contexts to associate with the execution.
    execution_id: The id of the execution.
    executor_output: The output of the executor.
  """
  [execution] = metadata_handler.store.get_executions_by_id([execution_id])
  execution.last_known_state = metadata_store_pb2.Execution.FAILED
  _set_execution_result_if_not_empty(executor_output, execution)

  execution_lib.put_execution(metadata_handler, execution, contexts)
def testGetExecutionsAssociatedWithAllContexts(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    contexts = self._generate_contexts(m)
    self.assertLen(contexts, 2)
    # Create 2 executions and associate with one context each.
    execution1 = execution_lib.prepare_execution(
        m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
        metadata_store_pb2.Execution.RUNNING)
    execution1 = execution_lib.put_execution(m, execution1, [contexts[0]])
    execution2 = execution_lib.prepare_execution(
        m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
        metadata_store_pb2.Execution.COMPLETE)
    execution2 = execution_lib.put_execution(m, execution2, [contexts[1]])
    # Create another execution and associate with both contexts.
    execution3 = execution_lib.prepare_execution(
        m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
        metadata_store_pb2.Execution.NEW)
    execution3 = execution_lib.put_execution(m, execution3, contexts)
    # Verify that the right executions are returned.
    with self.subTest(for_contexts=(0,)):
      executions = execution_lib.get_executions_associated_with_all_contexts(
          m, [contexts[0]])
      self.assertCountEqual([execution1.id, execution3.id],
                            [e.id for e in executions])
    with self.subTest(for_contexts=(1,)):
      executions = execution_lib.get_executions_associated_with_all_contexts(
          m, [contexts[1]])
      self.assertCountEqual([execution2.id, execution3.id],
                            [e.id for e in executions])
    with self.subTest(for_contexts=(0, 1)):
      executions = execution_lib.get_executions_associated_with_all_contexts(
          m, contexts)
      self.assertCountEqual([execution3.id], [e.id for e in executions])
def new(cls, mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline) -> 'PipelineState':
  """Creates a `PipelineState` object for a new pipeline.

  No active pipeline with the same pipeline uid should exist for the call to
  be successful.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline.

  Returns:
    A `PipelineState` object.

  Raises:
    status_lib.StatusNotOkError: If a pipeline with the same UID already
      exists.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=orchestrator_context_name(pipeline_uid))

  executions = mlmd_handle.store.get_executions_by_context(context.id)
  if any(e for e in executions if execution_lib.is_execution_active(e)):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.ALREADY_EXISTS,
        message=f'Pipeline with uid {pipeline_uid} already active.')

  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties={
          _PIPELINE_IR:
              base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
      },
  )
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    data_types_utils.set_metadata_value(
        execution.custom_properties[_PIPELINE_RUN_ID],
        pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  record_state_change_time()

  return cls(
      mlmd_handle=mlmd_handle, pipeline=pipeline, execution_id=execution.id)
def testGetArtifactsDict(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    # Create and shuffle a few artifacts. The shuffled order should be
    # retained in the output of `execution_lib.get_artifacts_dict`.
    input_examples = []
    for i in range(10):
      input_example = standard_artifacts.Examples()
      input_example.uri = 'example{}'.format(i)
      input_example.type_id = common_utils.register_type_if_not_exist(
          m, input_example.artifact_type).id
      input_examples.append(input_example)
    random.shuffle(input_examples)
    output_models = []
    for i in range(8):
      output_model = standard_artifacts.Model()
      output_model.uri = 'model{}'.format(i)
      output_model.type_id = common_utils.register_type_if_not_exist(
          m, output_model.artifact_type).id
      output_models.append(output_model)
    random.shuffle(output_models)
    m.store.put_artifacts([
        a.mlmd_artifact
        for a in itertools.chain(input_examples, output_models)
    ])
    execution = execution_lib.prepare_execution(
        m,
        metadata_store_pb2.ExecutionType(name='my_execution_type'),
        state=metadata_store_pb2.Execution.RUNNING)
    contexts = self._generate_contexts(m)
    input_artifacts_dict = {'examples': input_examples}
    output_artifacts_dict = {'model': output_models}
    execution = execution_lib.put_execution(
        m,
        execution,
        contexts,
        input_artifacts=input_artifacts_dict,
        output_artifacts=output_artifacts_dict)
    # Verify that the same artifacts are returned in the correct order.
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, metadata_store_pb2.Event.INPUT)
    self.assertCountEqual(['examples'], list(artifacts_dict.keys()))
    self.assertEqual([ex.uri for ex in input_examples],
                     [a.uri for a in artifacts_dict['examples']])
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, metadata_store_pb2.Event.OUTPUT)
    self.assertCountEqual(['model'], list(artifacts_dict.keys()))
    self.assertEqual([model.uri for model in output_models],
                     [a.uri for a in artifacts_dict['model']])
def initiate_pipeline_start(
    mlmd_handle: metadata.Metadata,
    pipeline: pipeline_pb2.Pipeline) -> metadata_store_pb2.Execution:
  """Initiates a pipeline start operation.

  Upon success, MLMD is updated to signal that the given pipeline must be
  started.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline to start.

  Returns:
    The pipeline-level MLMD execution proto upon success.

  Raises:
    status_lib.StatusNotOkError: Failure to initiate pipeline start, e.g. if
      an active pipeline with the same uid already exists.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=_orchestrator_context_name(pipeline_uid))

  executions = mlmd_handle.store.get_executions_by_context(context.id)
  if any(e for e in executions if execution_lib.is_execution_active(e)):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.ALREADY_EXISTS,
        message=f'Pipeline with uid {pipeline_uid} already started.')

  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties={
          _PIPELINE_IR:
              base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
      })
  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  logging.info('Registered execution (id: %s) for the pipeline with uid: %s',
               execution.id, pipeline_uid)
  return execution
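# A hedged sketch (an assumption, not shown in the excerpt above) of how the
# pipeline IR stored by `initiate_pipeline_start` under the `_PIPELINE_IR`
# property could be recovered, inverting the base64 encoding used above.
def _decode_pipeline_ir_sketch(ir_b64: str) -> pipeline_pb2.Pipeline:
  """Inverts `base64.b64encode(pipeline.SerializeToString()).decode('utf-8')`."""
  pipeline = pipeline_pb2.Pipeline()
  pipeline.ParseFromString(base64.b64decode(ir_b64))
  return pipeline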
def testGetArtifactIdsForExecutionIdGroupedByEventType(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    # Register an input artifact and an output artifact in MLMD.
    input_example = standard_artifacts.Examples()
    input_example.uri = 'example'
    input_example.type_id = common_utils.register_type_if_not_exist(
        m, input_example.artifact_type).id
    output_model = standard_artifacts.Model()
    output_model.uri = 'model'
    output_model.type_id = common_utils.register_type_if_not_exist(
        m, output_model.artifact_type).id
    [input_example.id, output_model.id] = m.store.put_artifacts(
        [input_example.mlmd_artifact, output_model.mlmd_artifact])
    execution = execution_lib.prepare_execution(
        m,
        metadata_store_pb2.ExecutionType(name='my_execution_type'),
        exec_properties={
            'p1': 1,
            'p2': '2'
        },
        state=metadata_store_pb2.Execution.COMPLETE)
    contexts = self._generate_contexts(m)
    execution = execution_lib.put_execution(
        m,
        execution,
        contexts,
        input_artifacts={'example': [input_example]},
        output_artifacts={'model': [output_model]})

    artifact_ids_by_event_type = (
        execution_lib.get_artifact_ids_by_event_type_for_execution_id(
            m, execution.id))
    self.assertDictEqual(
        {
            metadata_store_pb2.Event.INPUT: set([input_example.id]),
            metadata_store_pb2.Event.OUTPUT: set([output_model.id]),
        }, artifact_ids_by_event_type)
def publish_succeeded_execution(
    metadata_handler: metadata.Metadata,
    execution_id: int,
    contexts: Sequence[metadata_store_pb2.Context],
    output_artifacts: Optional[MutableMapping[str,
                                              Sequence[types.Artifact]]] = None,
    executor_output: Optional[execution_result_pb2.ExecutorOutput] = None
) -> Optional[MutableMapping[str, List[types.Artifact]]]:
  """Marks an existing execution as success.

  Also publishes the output artifacts produced by the execution. This method
  will also merge the executor-produced info into the system-generated output
  artifacts. The `last_known_state` of the execution will be changed to
  `COMPLETE` and the output artifacts will be marked as `LIVE`.

  Args:
    metadata_handler: A handler to access MLMD.
    execution_id: The id of the execution to mark successful.
    contexts: MLMD contexts to associate with the execution.
    output_artifacts: Output artifacts skeleton of the execution, generated by
      the system. Each artifact will be linked with the execution through an
      event with type OUTPUT.
    executor_output: Executor outputs. `executor_output.output_artifacts` will
      be used to update the system-generated output artifacts passed in
      through the `output_artifacts` arg. There are three constraints on the
      update:
      1. The keys in `executor_output.output_artifacts` are expected to be a
         subset of the system-generated output artifacts dict.
      2. An update to a certain key should contain all the artifacts under
         that key.
      3. An update to an artifact should not change the type of the artifact.

  Returns:
    The possibly updated output_artifacts. Note that only outputs whose keys
    are in executor_output will be updated; all others are left untouched, so
    the result may be partially updated.

  Raises:
    RuntimeError: If the executor output to an output channel is partial.
  """
  output_artifacts = copy.deepcopy(output_artifacts) or {}
  output_artifacts = cast(MutableMapping[str, List[types.Artifact]],
                          output_artifacts)
  if executor_output:
    if not set(executor_output.output_artifacts.keys()).issubset(
        output_artifacts.keys()):
      raise RuntimeError(
          'Executor output %s contains more keys than output skeleton %s.' %
          (executor_output, output_artifacts))
    for key, artifact_list in output_artifacts.items():
      if key not in executor_output.output_artifacts:
        continue
      updated_artifact_list = executor_output.output_artifacts[key].artifacts
      # We assume the original output dict must include at least one output
      # artifact and that all artifacts in the list share the same type.
      original_artifact = artifact_list[0]
      # Update the artifact list with what's in the executor output.
      artifact_list.clear()
      # TODO(b/175426744): revisit this:
      # 1) whether multiple outputs are needed or not after TFX components
      #    are upgraded;
      # 2) if multiple outputs are needed and it is a common practice, should
      #    we use the driver instead to create the list of output artifacts
      #    rather than letting the executor create them.
      for proto_artifact in updated_artifact_list:
        _check_validity(proto_artifact, original_artifact)
        python_artifact = types.Artifact(original_artifact.artifact_type)
        python_artifact.set_mlmd_artifact(proto_artifact)
        artifact_list.append(python_artifact)

  # Marks output artifacts as LIVE.
  for artifact_list in output_artifacts.values():
    for artifact in artifact_list:
      artifact.mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE

  [execution] = metadata_handler.store.get_executions_by_id([execution_id])
  execution.last_known_state = metadata_store_pb2.Execution.COMPLETE

  execution_lib.put_execution(
      metadata_handler,
      execution,
      contexts,
      output_artifacts=output_artifacts)

  return output_artifacts
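# An illustrative sketch (names are hypothetical) of the merge contract of
# `publish_succeeded_execution` above: executor output keys must be a subset
# of the skeleton, and each updated artifact must keep the skeleton's type.
def _publish_with_executor_output_sketch(metadata_handler, contexts,
                                         execution_id, model_skeleton):
  executor_output = execution_result_pb2.ExecutorOutput()
  # Replace the skeleton's 'model' list with one executor-produced artifact.
  updated = executor_output.output_artifacts['model'].artifacts.add()
  updated.type_id = model_skeleton.type_id  # Constraint 3: same type.
  updated.uri = model_skeleton.uri  # The executor may refine URI/properties.
  return publish_succeeded_execution(
      metadata_handler,
      execution_id,
      contexts,
      output_artifacts={'model': [model_skeleton]},
      executor_output=executor_output)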
def new(
    cls,
    mlmd_handle: metadata.Metadata,
    pipeline: pipeline_pb2.Pipeline,
    pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
) -> 'PipelineState':
  """Creates a `PipelineState` object for a new pipeline.

  No active pipeline with the same pipeline uid should exist for the call to
  be successful.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline.
    pipeline_run_metadata: Pipeline run metadata.

  Returns:
    A `PipelineState` object.

  Raises:
    status_lib.StatusNotOkError: If a pipeline with the same UID already
      exists.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=orchestrator_context_name(pipeline_uid))

  executions = mlmd_handle.store.get_executions_by_context(context.id)
  if any(e for e in executions if execution_lib.is_execution_active(e)):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.ALREADY_EXISTS,
        message=f'Pipeline with uid {pipeline_uid} already active.')

  exec_properties = {_PIPELINE_IR: _base64_encode(pipeline)}
  if pipeline_run_metadata:
    exec_properties[_PIPELINE_RUN_METADATA] = json_utils.dumps(
        pipeline_run_metadata)

  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties=exec_properties)
  if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
    data_types_utils.set_metadata_value(
        execution.custom_properties[_PIPELINE_RUN_ID],
        pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

  # Set the node state to COMPLETE for any nodes that are marked to be
  # skipped in a partial pipeline run.
  node_states_dict = {}
  for node in get_all_pipeline_nodes(pipeline):
    if node.execution_options.HasField('skip'):
      logging.info('Node %s is skipped in this partial run.',
                   node.node_info.id)
      node_states_dict[node.node_info.id] = NodeState(
          state=NodeState.COMPLETE)
  if node_states_dict:
    _save_node_states_dict(execution, node_states_dict)

  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  record_state_change_time()

  return cls(
      mlmd_handle=mlmd_handle, pipeline=pipeline, execution_id=execution.id)
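# `_base64_encode` is used above but not defined in this excerpt. A plausible
# implementation, consistent with the inline encoding in the earlier
# `initiate_pipeline_start` and `new` variants, would be:
def _base64_encode(pipeline: pipeline_pb2.Pipeline) -> str:
  """Serializes the pipeline IR proto and base64-encodes it for MLMD storage."""
  return base64.b64encode(pipeline.SerializeToString()).decode('utf-8')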
def publish_succeeded_execution(
    metadata_handler: metadata.Metadata,
    execution_id: int,
    contexts: Sequence[metadata_store_pb2.Context],
    output_artifacts: Optional[MutableMapping[str,
                                              Sequence[types.Artifact]]] = None,
    executor_output: Optional[execution_result_pb2.ExecutorOutput] = None
) -> None:
  """Marks an existing execution as success.

  Also publishes the output artifacts produced by the execution. This method
  will also merge the executor-produced info into the system-generated output
  artifacts. The `last_known_state` of the execution will be changed to
  `COMPLETE` and the output artifacts will be marked as `LIVE`.

  Args:
    metadata_handler: A handler to access MLMD.
    execution_id: The id of the execution to mark successful.
    contexts: MLMD contexts to associate with the execution.
    output_artifacts: Output artifacts skeleton of the execution, generated by
      the system. Each artifact will be linked with the execution through an
      event with type OUTPUT.
    executor_output: Executor outputs. `executor_output.output_artifacts` will
      be used to update the system-generated output artifacts passed in
      through the `output_artifacts` arg. There are three constraints on the
      update:
      1. The keys in `executor_output.output_artifacts` are expected to be a
         subset of the system-generated output artifacts dict.
      2. An update to a certain key should contain all the artifacts under
         that key.
      3. An update to an artifact should not change the type of the artifact.

  Raises:
    RuntimeError: If the executor output to an output channel is partial.
  """
  output_artifacts = output_artifacts or {}
  if executor_output:
    if not set(executor_output.output_artifacts.keys()).issubset(
        output_artifacts.keys()):
      raise RuntimeError(
          'Executor output %s contains more keys than output skeleton %s.' %
          (executor_output, output_artifacts))
    for key, artifact_list in output_artifacts.items():
      if key not in executor_output.output_artifacts:
        continue
      updated_artifact_list = executor_output.output_artifacts[key].artifacts
      if len(artifact_list) != len(updated_artifact_list):
        raise RuntimeError(
            'Partially updating an output channel is not supported.')
      for original, updated in zip(artifact_list, updated_artifact_list):
        if original.type_id != updated.type_id:
          raise RuntimeError(
              'Executor output should not change artifact type.')
        original.mlmd_artifact.CopyFrom(updated)

  # Marks output artifacts as LIVE.
  for artifact_list in output_artifacts.values():
    for artifact in artifact_list:
      artifact.mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE

  [execution] = metadata_handler.store.get_executions_by_id([execution_id])
  execution.last_known_state = metadata_store_pb2.Execution.COMPLETE

  execution_lib.put_execution(
      metadata_handler,
      execution,
      contexts,
      output_artifacts=output_artifacts)
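# A small sketch (hypothetical names) of the partial-update constraint this
# variant enforces: the executor output must cover every artifact under a
# key, otherwise a RuntimeError is raised.
def _partial_update_raises_sketch(metadata_handler, contexts, execution_id,
                                  two_model_skeletons):
  assert len(two_model_skeletons) == 2
  executor_output = execution_result_pb2.ExecutorOutput()
  executor_output.output_artifacts['model'].artifacts.add()  # Only 1 of 2.
  try:
    publish_succeeded_execution(
        metadata_handler,
        execution_id,
        contexts,
        output_artifacts={'model': list(two_model_skeletons)},
        executor_output=executor_output)
  except RuntimeError:
    pass  # Expected: partially updating an output channel is not supported.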
def testGetArtifactsDict(self):
  with metadata.Metadata(connection_config=self._connection_config) as m:
    # Create and shuffle a few artifacts. The shuffled order should be
    # retained in the output of `execution_lib.get_artifacts_dict`.
    input_artifact_keys = ('input1', 'input2', 'input3')
    input_artifacts_dict = collections.OrderedDict()
    for input_key in input_artifact_keys:
      input_examples = []
      for i in range(10):
        input_example = standard_artifacts.Examples()
        input_example.uri = f'{input_key}/example{i}'
        input_example.type_id = common_utils.register_type_if_not_exist(
            m, input_example.artifact_type).id
        input_examples.append(input_example)
      random.shuffle(input_examples)
      input_artifacts_dict[input_key] = input_examples

    output_models = []
    for i in range(8):
      output_model = standard_artifacts.Model()
      output_model.uri = f'model{i}'
      output_model.type_id = common_utils.register_type_if_not_exist(
          m, output_model.artifact_type).id
      output_models.append(output_model)
    random.shuffle(output_models)
    output_artifacts_dict = {'model': output_models}

    # Store input artifacts only. Outputs will be saved in put_execution().
    input_mlmd_artifacts = [
        a.mlmd_artifact
        for a in itertools.chain(*input_artifacts_dict.values())
    ]
    artifact_ids = m.store.put_artifacts(input_mlmd_artifacts)
    for artifact_id, mlmd_artifact in zip(artifact_ids,
                                          input_mlmd_artifacts):
      mlmd_artifact.id = artifact_id

    execution = execution_lib.prepare_execution(
        m,
        metadata_store_pb2.ExecutionType(name='my_execution_type'),
        state=metadata_store_pb2.Execution.RUNNING)
    contexts = self._generate_contexts(m)

    # Change the order of the OrderedDict to shuffle the order of input keys.
    input_artifacts_dict.move_to_end('input1')
    execution = execution_lib.put_execution(
        m,
        execution,
        contexts,
        input_artifacts=input_artifacts_dict,
        output_artifacts=output_artifacts_dict)

    # Verify that the same artifacts are returned in the correct order.
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, [metadata_store_pb2.Event.INPUT])
    self.assertEqual(set(input_artifact_keys), set(artifacts_dict.keys()))
    for key in artifacts_dict:
      self.assertEqual([ex.uri for ex in input_artifacts_dict[key]],
                       [a.uri for a in artifacts_dict[key]],
                       f'for key={key}')
    artifacts_dict = execution_lib.get_artifacts_dict(
        m, execution.id, [metadata_store_pb2.Event.OUTPUT])
    self.assertEqual({'model'}, set(artifacts_dict.keys()))
    self.assertEqual([model.uri for model in output_models],
                     [a.uri for a in artifacts_dict['model']])