Example #1
def publish_internal_execution(
    metadata_handler: metadata.Metadata,
    contexts: Sequence[metadata_store_pb2.Context],
    execution_id: int,
    output_artifacts: Optional[MutableMapping[str,
                                              Sequence[types.Artifact]]] = None
) -> None:
    """Marks an exeisting execution as as success and links its output to an INTERNAL_OUTPUT event.

  Args:
    metadata_handler: A handler to access MLMD.
    contexts: MLMD contexts to associated with the execution.
    execution_id: The id of the execution.
    output_artifacts: Output artifacts of the execution. Each artifact will be
      linked with the execution through an event with type INTERNAL_OUTPUT.
  """
    [execution] = metadata_handler.store.get_executions_by_id([execution_id])
    execution.last_known_state = metadata_store_pb2.Execution.COMPLETE

    execution_lib.put_execution(
        metadata_handler,
        execution,
        contexts,
        output_artifacts=output_artifacts,
        output_event_type=metadata_store_pb2.Event.INTERNAL_OUTPUT)
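A minimal usage sketch for the function above. Hedged: `m`, `contexts`, and the artifact uri are illustrative stand-ins, and `register_execution` is the registration helper shown in Example #4.

# Sketch only: assumes `m` is an open metadata.Metadata handle and `contexts`
# were created elsewhere; `register_execution` is defined in Example #4.
execution = register_execution(
    m, metadata_store_pb2.ExecutionType(name='my_execution_type'), contexts)
model = standard_artifacts.Model()
model.uri = '/tmp/model'  # illustrative uri
publish_internal_execution(
    m, contexts, execution.id, output_artifacts={'model': [model]})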
Example #2
def publish_cached_execution(
    metadata_handler: metadata.Metadata,
    contexts: Sequence[metadata_store_pb2.Context],
    execution_id: int,
    output_artifacts: Optional[MutableMapping[str,
                                              Sequence[types.Artifact]]] = None,
) -> None:
  """Marks an existing execution as using cached outputs from a previous execution.

  Args:
    metadata_handler: A handler to access MLMD.
    contexts: MLMD contexts to be associated with the execution.
    execution_id: The id of the execution.
    output_artifacts: Output artifacts of the execution. Each artifact will be
      linked with the execution through an event with type OUTPUT.
  """
  [execution] = metadata_handler.store.get_executions_by_id([execution_id])
  execution.last_known_state = metadata_store_pb2.Execution.CACHED

  execution_lib.put_execution(
      metadata_handler,
      execution,
      contexts,
      input_artifacts=None,
      output_artifacts=output_artifacts)
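A hedged sketch of the cache-hit flow this function supports: reuse the outputs of a previous successful execution. `previous_execution` is assumed to exist, and note that the test examples later in this section disagree on whether `get_artifacts_dict` takes a single event type or a list, so the call below is an assumption.

# Sketch only: re-link a prior execution's outputs to a new CACHED execution.
cached_outputs = execution_lib.get_artifacts_dict(
    m, previous_execution.id, [metadata_store_pb2.Event.OUTPUT])
execution = register_execution(
    m, metadata_store_pb2.ExecutionType(name='my_execution_type'), contexts)
publish_cached_execution(
    m, contexts, execution.id, output_artifacts=cached_outputs)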
Example #3
def publish_failed_execution(metadata_handler: metadata.Metadata,
                             contexts: Sequence[metadata_store_pb2.Context],
                             execution_id: int) -> None:
    """Marks an existing execution as failed.

  Args:
    metadata_handler: A handler to access MLMD.
    contexts: MLMD contexts to associated with the execution.
    execution_id: The id of the execution.
  """
    [execution] = metadata_handler.store.get_executions_by_id([execution_id])
    execution.last_known_state = metadata_store_pb2.Execution.FAILED

    execution_lib.put_execution(metadata_handler, execution, contexts)
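A hedged sketch of the typical call site: mark the registered execution FAILED when the executor raises. `run_executor` and `execution_info` are hypothetical stand-ins, not part of the API above.

try:
  run_executor(execution_info)  # hypothetical executor invocation
except Exception:
  # Record the failure in MLMD before propagating the error.
  publish_failed_execution(m, contexts, execution.id)
  raise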
Example #4
def register_execution(
    metadata_handler: metadata.Metadata,
    execution_type: metadata_store_pb2.ExecutionType,
    contexts: Sequence[metadata_store_pb2.Context],
    input_artifacts: Optional[MutableMapping[str,
                                             Sequence[types.Artifact]]] = None,
    exec_properties: Optional[Mapping[str, types.Property]] = None,
) -> metadata_store_pb2.Execution:
    """Registers a new execution in MLMD.

  Along with the execution:
  -  the input artifacts will be linked to the execution.
  -  the contexts will be linked to both the execution and its input artifacts.

  Args:
    metadata_handler: A handler to access MLMD.
    execution_type: The type of the execution.
    contexts: MLMD contexts to associated with the execution.
    input_artifacts: Input artifacts of the execution. Each artifact will be
      linked with the execution through an event.
    exec_properties: Execution properties. Will be attached to the execution.

  Returns:
    An MLMD execution that is registered in MLMD, with id populated.
  """
    execution = execution_lib.prepare_execution(
        metadata_handler, execution_type, metadata_store_pb2.Execution.RUNNING,
        exec_properties)
    return execution_lib.put_execution(metadata_handler,
                                       execution,
                                       contexts,
                                       input_artifacts=input_artifacts)
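A minimal sketch of registering an execution with inputs. Names and values are illustrative; per the comment in Example #5, input artifacts are expected to already exist in MLMD before being linked.

examples = standard_artifacts.Examples()
examples.uri = '/tmp/examples'  # illustrative; assumed already registered in MLMD
execution = register_execution(
    m,
    metadata_store_pb2.ExecutionType(name='my_execution_type'),
    contexts,
    input_artifacts={'examples': [examples]},
    exec_properties={'num_steps': 100})  # illustrative property
assert execution.id  # id is populated by MLMD on registration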
Example #5
    def testPutExecutionGraph(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Prepares an input artifact. The artifact should be registered in MLMD
            # before the put_execution call.
            input_example = standard_artifacts.Examples()
            input_example.uri = 'example'
            input_example.type_id = common_utils.register_type_if_not_exist(
                m, input_example.artifact_type).id
            [input_example.id
             ] = m.store.put_artifacts([input_example.mlmd_artifact])
            # Prepares an output artifact.
            output_model = standard_artifacts.Model()
            output_model.uri = 'model'
            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                exec_properties={
                    'p1': 1,
                    'p2': '2'
                },
                state=metadata_store_pb2.Execution.COMPLETE)
            contexts = self._generate_contexts(m)
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts={'example': [input_example]},
                output_artifacts={'model': [output_model]})

            self.assertProtoPartiallyEquals(
                output_model.mlmd_artifact,
                m.store.get_artifacts_by_id([output_model.id])[0],
                ignored_fields=[
                    'create_time_since_epoch', 'last_update_time_since_epoch'
                ])
            # Verifies edges between artifacts and execution.
            [input_event
             ] = m.store.get_events_by_artifact_ids([input_example.id])
            self.assertEqual(input_event.execution_id, execution.id)
            self.assertEqual(input_event.type, metadata_store_pb2.Event.INPUT)
            [output_event
             ] = m.store.get_events_by_artifact_ids([output_model.id])
            self.assertEqual(output_event.execution_id, execution.id)
            self.assertEqual(output_event.type,
                             metadata_store_pb2.Event.OUTPUT)
            # Verifies edges connecting contexts and {artifacts, execution}.
            context_ids = [context.id for context in contexts]
            self.assertCountEqual([
                c.id
                for c in m.store.get_contexts_by_artifact(input_example.id)
            ], context_ids)
            self.assertCountEqual([
                c.id for c in m.store.get_contexts_by_artifact(output_model.id)
            ], context_ids)
            self.assertCountEqual([
                c.id for c in m.store.get_contexts_by_execution(execution.id)
            ], context_ids)
Example #6
  def commit(self) -> None:
    """Commits pipeline state to MLMD if there are any mutations."""
    if self._commit:
      self.execution = execution_lib.put_execution(
          self.mlmd_handle, self.execution, [self.context])
      logging.info('Committed execution (id: %s) for pipeline with uid: %s',
                   self.execution.id, self.pipeline_uid)
    self._commit = False
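A hypothetical mutator on the same class, sketching how the `_commit` dirty flag is meant to be set so that `commit()` performs exactly one MLMD write. This method is an assumption for illustration, not part of the class shown above.

  def set_pipeline_execution_state(
      self, state: metadata_store_pb2.Execution.State) -> None:
    # Hypothetical example: mutate the cached execution proto and mark the
    # state dirty; commit() will persist it to MLMD later.
    self.execution.last_known_state = state
    self._commit = True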
Example #7
def publish_failed_execution(
    metadata_handler: metadata.Metadata,
    contexts: Sequence[metadata_store_pb2.Context],
    execution_id: int,
    executor_output: Optional[execution_result_pb2.ExecutorOutput] = None
) -> None:
    """Marks an existing execution as failed.

  Args:
    metadata_handler: A handler to access MLMD.
    contexts: MLMD contexts to associated with the execution.
    execution_id: The id of the execution.
    executor_output: The output of executor.
  """
    [execution] = metadata_handler.store.get_executions_by_id([execution_id])
    execution.last_known_state = metadata_store_pb2.Execution.FAILED
    _set_execution_result_if_not_empty(executor_output, execution)

    execution_lib.put_execution(metadata_handler, execution, contexts)
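A hedged sketch of passing the executor's error details along. It assumes the `ExecutorOutput` proto carries an `execution_result` with `code` and `result_message` fields, as suggested by `_set_execution_result_if_not_empty` above; the values are illustrative.

executor_output = execution_result_pb2.ExecutorOutput()
executor_output.execution_result.code = 1  # non-zero code: failure
executor_output.execution_result.result_message = 'Executor failed.'
publish_failed_execution(m, contexts, execution.id, executor_output)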
Example #8
    def testGetExecutionsAssociatedWithAllContexts(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            self.assertLen(contexts, 2)

            # Create 2 executions and associate with one context each.
            execution1 = execution_lib.prepare_execution(
                m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
                metadata_store_pb2.Execution.RUNNING)
            execution1 = execution_lib.put_execution(m, execution1,
                                                     [contexts[0]])
            execution2 = execution_lib.prepare_execution(
                m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
                metadata_store_pb2.Execution.COMPLETE)
            execution2 = execution_lib.put_execution(m, execution2,
                                                     [contexts[1]])

            # Create another execution and associate with both contexts.
            execution3 = execution_lib.prepare_execution(
                m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
                metadata_store_pb2.Execution.NEW)
            execution3 = execution_lib.put_execution(m, execution3, contexts)

            # Verify that the right executions are returned.
            with self.subTest(for_contexts=(0, )):
                executions = execution_lib.get_executions_associated_with_all_contexts(
                    m, [contexts[0]])
                self.assertCountEqual([execution1.id, execution3.id],
                                      [e.id for e in executions])
            with self.subTest(for_contexts=(1, )):
                executions = execution_lib.get_executions_associated_with_all_contexts(
                    m, [contexts[1]])
                self.assertCountEqual([execution2.id, execution3.id],
                                      [e.id for e in executions])
            with self.subTest(for_contexts=(0, 1)):
                executions = execution_lib.get_executions_associated_with_all_contexts(
                    m, contexts)
                self.assertCountEqual([execution3.id],
                                      [e.id for e in executions])
Example #9
  @classmethod
  def new(cls, mlmd_handle: metadata.Metadata,
          pipeline: pipeline_pb2.Pipeline) -> 'PipelineState':
    """Creates a `PipelineState` object for a new pipeline.

    No active pipeline with the same pipeline uid should exist for the call to
    be successful.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: IR of the pipeline.

    Returns:
      A `PipelineState` object.

    Raises:
      status_lib.StatusNotOkError: If a pipeline with the same UID already
        exists.
    """
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    context = context_lib.register_context_if_not_exists(
        mlmd_handle,
        context_type_name=_ORCHESTRATOR_RESERVED_ID,
        context_name=orchestrator_context_name(pipeline_uid))

    executions = mlmd_handle.store.get_executions_by_context(context.id)
    if any(e for e in executions if execution_lib.is_execution_active(e)):
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.ALREADY_EXISTS,
          message=f'Pipeline with uid {pipeline_uid} already active.')

    execution = execution_lib.prepare_execution(
        mlmd_handle,
        _ORCHESTRATOR_EXECUTION_TYPE,
        metadata_store_pb2.Execution.NEW,
        exec_properties={
            _PIPELINE_IR:
                base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
        },
    )
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
      data_types_utils.set_metadata_value(
          execution.custom_properties[_PIPELINE_RUN_ID],
          pipeline.runtime_spec.pipeline_run_id.field_value.string_value)

    execution = execution_lib.put_execution(mlmd_handle, execution, [context])
    record_state_change_time()

    return cls(
        mlmd_handle=mlmd_handle, pipeline=pipeline, execution_id=execution.id)
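A hedged orchestrator-side sketch with illustrative names (`connection_config`, `pipeline_ir`): creating state for a new pipeline run fails cleanly if a run with the same uid is already active.

with metadata.Metadata(connection_config=connection_config) as m:
  try:
    pipeline_state = PipelineState.new(m, pipeline_ir)
  except status_lib.StatusNotOkError as e:
    # ALREADY_EXISTS is raised when an active run with the same uid exists.
    logging.warning('Could not start pipeline: %s', e)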
Example #10
    def testGetArtifactsDict(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Create and shuffle a few artifacts. The shuffled order should be
            # retained in the output of `execution_lib.get_artifacts_dict`.
            input_examples = []
            for i in range(10):
                input_example = standard_artifacts.Examples()
                input_example.uri = 'example{}'.format(i)
                input_example.type_id = common_utils.register_type_if_not_exist(
                    m, input_example.artifact_type).id
                input_examples.append(input_example)
            random.shuffle(input_examples)
            output_models = []
            for i in range(8):
                output_model = standard_artifacts.Model()
                output_model.uri = 'model{}'.format(i)
                output_model.type_id = common_utils.register_type_if_not_exist(
                    m, output_model.artifact_type).id
                output_models.append(output_model)
            random.shuffle(output_models)
            m.store.put_artifacts([
                a.mlmd_artifact
                for a in itertools.chain(input_examples, output_models)
            ])
            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                state=metadata_store_pb2.Execution.RUNNING)
            contexts = self._generate_contexts(m)
            input_artifacts_dict = {'examples': input_examples}
            output_artifacts_dict = {'model': output_models}
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts=input_artifacts_dict,
                output_artifacts=output_artifacts_dict)

            # Verify that the same artifacts are returned in the correct order.
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, metadata_store_pb2.Event.INPUT)
            self.assertCountEqual(['examples'], list(artifacts_dict.keys()))
            self.assertEqual([ex.uri for ex in input_examples],
                             [a.uri for a in artifacts_dict['examples']])
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, metadata_store_pb2.Event.OUTPUT)
            self.assertCountEqual(['model'], list(artifacts_dict.keys()))
            self.assertEqual([model.uri for model in output_models],
                             [a.uri for a in artifacts_dict['model']])
Example #11
def initiate_pipeline_start(
    mlmd_handle: metadata.Metadata,
    pipeline: pipeline_pb2.Pipeline) -> metadata_store_pb2.Execution:
  """Initiates a pipeline start operation.

  Upon success, MLMD is updated to signal that the given pipeline must be
  started.

  Args:
    mlmd_handle: A handle to the MLMD db.
    pipeline: IR of the pipeline to start.

  Returns:
    The pipeline-level MLMD execution proto upon success.

  Raises:
    status_lib.StatusNotOkError: If the pipeline start could not be initiated,
      e.g. because a pipeline with the same uid is already active.
  """
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  context = context_lib.register_context_if_not_exists(
      mlmd_handle,
      context_type_name=_ORCHESTRATOR_RESERVED_ID,
      context_name=_orchestrator_context_name(pipeline_uid))

  executions = mlmd_handle.store.get_executions_by_context(context.id)
  if any(e for e in executions if execution_lib.is_execution_active(e)):
    raise status_lib.StatusNotOkError(
        code=status_lib.Code.ALREADY_EXISTS,
        message=f'Pipeline with uid {pipeline_uid} already started.')

  execution = execution_lib.prepare_execution(
      mlmd_handle,
      _ORCHESTRATOR_EXECUTION_TYPE,
      metadata_store_pb2.Execution.NEW,
      exec_properties={
          _PIPELINE_IR:
              base64.b64encode(pipeline.SerializeToString()).decode('utf-8')
      })
  execution = execution_lib.put_execution(mlmd_handle, execution, [context])
  logging.info('Registered execution (id: %s) for the pipeline with uid: %s',
               execution.id, pipeline_uid)
  return execution
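A small sketch of the IR encoding used above: the pipeline proto round-trips through a base64-encoded string stored as an execution property (`pipeline_ir` is an illustrative `pipeline_pb2.Pipeline`).

encoded = base64.b64encode(pipeline_ir.SerializeToString()).decode('utf-8')
decoded = pipeline_pb2.Pipeline()
decoded.ParseFromString(base64.b64decode(encoded))
assert decoded == pipeline_ir  # proto equality holds after the round-trip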
Example #12
    def testGetArtifactIdsForExecutionIdGroupedByEventType(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Register an input and output artifacts in MLMD.
            input_example = standard_artifacts.Examples()
            input_example.uri = 'example'
            input_example.type_id = common_utils.register_type_if_not_exist(
                m, input_example.artifact_type).id
            output_model = standard_artifacts.Model()
            output_model.uri = 'model'
            output_model.type_id = common_utils.register_type_if_not_exist(
                m, output_model.artifact_type).id
            [input_example.id, output_model.id] = m.store.put_artifacts(
                [input_example.mlmd_artifact, output_model.mlmd_artifact])
            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                exec_properties={
                    'p1': 1,
                    'p2': '2'
                },
                state=metadata_store_pb2.Execution.COMPLETE)
            contexts = self._generate_contexts(m)
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts={'example': [input_example]},
                output_artifacts={'model': [output_model]})

            artifact_ids_by_event_type = (
                execution_lib.get_artifact_ids_by_event_type_for_execution_id(
                    m, execution.id))
            self.assertDictEqual(
                {
                    metadata_store_pb2.Event.INPUT: set([input_example.id]),
                    metadata_store_pb2.Event.OUTPUT: set([output_model.id]),
                }, artifact_ids_by_event_type)
Example #13
def publish_succeeded_execution(
    metadata_handler: metadata.Metadata,
    execution_id: int,
    contexts: Sequence[metadata_store_pb2.Context],
    output_artifacts: Optional[MutableMapping[
        str, Sequence[types.Artifact]]] = None,
    executor_output: Optional[execution_result_pb2.ExecutorOutput] = None
) -> Optional[MutableMapping[str, List[types.Artifact]]]:
    """Marks an existing execution as success.

  Also publishes the output artifacts produced by the execution. This method
  will also merge the executor produced info into system generated output
  artifacts. The `last_know_state` of the execution will be changed to
  `COMPLETE` and the output artifacts will be marked as `LIVE`.

  Args:
    metadata_handler: A handler to access MLMD.
    execution_id: The id of the execution to mark successful.
    contexts: MLMD contexts to associated with the execution.
    output_artifacts: Output artifacts skeleton of the execution, generated by
      the system. Each artifact will be linked with the execution through an
      event with type OUTPUT.
    executor_output: Executor outputs. `executor_output.output_artifacts` will
      be used to update system-generated output artifacts passed in through
      `output_artifacts` arg. There are three contraints to the update: 1. The
        keys in `executor_output.output_artifacts` are expected to be a subset
        of the system-generated output artifacts dict. 2. An update to a certain
        key should contains all the artifacts under that key. 3. An update to an
        artifact should not change the type of the artifact.

  Returns:
    The maybe updated output_artifacts, note that only outputs whose key are in
    executor_output will be updated and others will be untouched. That said,
    it can be partially updated.
  Raises:
    RuntimeError: if the executor output to a output channel is partial.
  """
    output_artifacts = copy.deepcopy(output_artifacts) or {}
    output_artifacts = cast(MutableMapping[str, List[types.Artifact]],
                            output_artifacts)
    if executor_output:
        if not set(executor_output.output_artifacts.keys()).issubset(
                output_artifacts.keys()):
            raise RuntimeError(
                'Executor output %s contains more keys than output skeleton %s.'
                % (executor_output, output_artifacts))
        for key, artifact_list in output_artifacts.items():
            if key not in executor_output.output_artifacts:
                continue
            updated_artifact_list = executor_output.output_artifacts[
                key].artifacts

            # We assume the original output dict must include at least one output
            # artifact and all artifacts in the list share the same type.
            original_artifact = artifact_list[0]
            # Update the artifact list with what's in the executor output
            artifact_list.clear()
            # TODO(b/175426744): revisit this:
            # 1) Whether multiple output is needed or not after TFX componets
            #    are upgraded.
            # 2) If multiple output are needed and is a common practice, should we
            #    use driver instead to create the list of output artifact instead
            #    of letting executor to create them.
            for proto_artifact in updated_artifact_list:
                _check_validity(proto_artifact, original_artifact)
                python_artifact = types.Artifact(
                    original_artifact.artifact_type)
                python_artifact.set_mlmd_artifact(proto_artifact)
                artifact_list.append(python_artifact)

    # Marks output artifacts as LIVE.
    for artifact_list in output_artifacts.values():
        for artifact in artifact_list:
            artifact.mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE

    [execution] = metadata_handler.store.get_executions_by_id([execution_id])
    execution.last_known_state = metadata_store_pb2.Execution.COMPLETE

    execution_lib.put_execution(metadata_handler,
                                execution,
                                contexts,
                                output_artifacts=output_artifacts)

    return output_artifacts
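A hedged usage sketch with illustrative names: the executor returns an updated copy of the output artifact, and the published result comes back marked LIVE.

executor_output = execution_result_pb2.ExecutorOutput()
executor_output.output_artifacts['model'].artifacts.add().CopyFrom(
    model.mlmd_artifact)  # same type, so _check_validity passes
updated = publish_succeeded_execution(
    m, execution.id, contexts, {'model': [model]}, executor_output)
assert updated['model'][0].mlmd_artifact.state == (
    metadata_store_pb2.Artifact.LIVE)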
Example #14
  @classmethod
  def new(
      cls,
      mlmd_handle: metadata.Metadata,
      pipeline: pipeline_pb2.Pipeline,
      pipeline_run_metadata: Optional[Mapping[str, types.Property]] = None,
  ) -> 'PipelineState':
    """Creates a `PipelineState` object for a new pipeline.

    No active pipeline with the same pipeline uid should exist for the call to
    be successful.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: IR of the pipeline.
      pipeline_run_metadata: Pipeline run metadata.

    Returns:
      A `PipelineState` object.

    Raises:
      status_lib.StatusNotOkError: If a pipeline with the same UID already
        exists.
    """
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    context = context_lib.register_context_if_not_exists(
        mlmd_handle,
        context_type_name=_ORCHESTRATOR_RESERVED_ID,
        context_name=orchestrator_context_name(pipeline_uid))

    executions = mlmd_handle.store.get_executions_by_context(context.id)
    if any(e for e in executions if execution_lib.is_execution_active(e)):
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.ALREADY_EXISTS,
          message=f'Pipeline with uid {pipeline_uid} already active.')

    exec_properties = {_PIPELINE_IR: _base64_encode(pipeline)}
    if pipeline_run_metadata:
      exec_properties[_PIPELINE_RUN_METADATA] = json_utils.dumps(
          pipeline_run_metadata)

    execution = execution_lib.prepare_execution(
        mlmd_handle,
        _ORCHESTRATOR_EXECUTION_TYPE,
        metadata_store_pb2.Execution.NEW,
        exec_properties=exec_properties)
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
      data_types_utils.set_metadata_value(
          execution.custom_properties[_PIPELINE_RUN_ID],
          pipeline.runtime_spec.pipeline_run_id.field_value.string_value)
      # Set the node state to COMPLETE for any nodes that are marked to be
      # skipped in a partial pipeline run.
      node_states_dict = {}
      for node in get_all_pipeline_nodes(pipeline):
        if node.execution_options.HasField('skip'):
          logging.info('Node %s is skipped in this partial run.',
                       node.node_info.id)
          node_states_dict[node.node_info.id] = NodeState(
              state=NodeState.COMPLETE)
      if node_states_dict:
        _save_node_states_dict(execution, node_states_dict)

    execution = execution_lib.put_execution(mlmd_handle, execution, [context])
    record_state_change_time()

    return cls(
        mlmd_handle=mlmd_handle, pipeline=pipeline, execution_id=execution.id)
Example #15
def publish_succeeded_execution(
    metadata_handler: metadata.Metadata,
    execution_id: int,
    contexts: Sequence[metadata_store_pb2.Context],
    output_artifacts: Optional[MutableMapping[
        str, Sequence[types.Artifact]]] = None,
    executor_output: Optional[execution_result_pb2.ExecutorOutput] = None
) -> None:
    """Marks an existing execution as success.

  Also publishes the output artifacts produced by the execution. This method
  will also merge the executor produced info into system generated output
  artifacts. The `last_know_state` of the execution will be changed to
  `COMPLETE` and the output artifacts will be marked as `LIVE`.

  Args:
    metadata_handler: A handler to access MLMD.
    execution_id: The id of the execution to mark successful.
    contexts: MLMD contexts to associated with the execution.
    output_artifacts: Output artifacts skeleton of the execution, generated by
      the system. Each artifact will be linked with the execution through an
      event with type OUTPUT.
    executor_output: Executor outputs. `executor_output.output_artifacts` will
      be used to update system-generated output artifacts passed in through
      `output_artifacts` arg. There are three contraints to the update: 1. The
        keys in `executor_output.output_artifacts` are expected to be a subset
        of the system-generated output artifacts dict. 2. An update to a certain
        key should contains all the artifacts under that key. 3. An update to an
        artifact should not change the type of the artifact.

  Raises:
    RuntimeError: if the executor output to a output channel is partial.
  """
    output_artifacts = output_artifacts or {}
    if executor_output:
        if not set(executor_output.output_artifacts.keys()).issubset(
                output_artifacts.keys()):
            raise RuntimeError(
                'Executor output %s contains more keys than output skeleton %s.'
                % (executor_output, output_artifacts))
        for key, artifact_list in output_artifacts.items():
            if key not in executor_output.output_artifacts:
                continue
            updated_artifact_list = executor_output.output_artifacts[
                key].artifacts
            if len(artifact_list) != len(updated_artifact_list):
                raise RuntimeError(
                    'Partially update an output channel is not supported')

            for original, updated in zip(artifact_list, updated_artifact_list):
                if original.type_id != updated.type_id:
                    raise RuntimeError(
                        'Executor output should not change artifact type.')
                original.mlmd_artifact.CopyFrom(updated)

    # Marks output artifacts as LIVE.
    for artifact_list in output_artifacts.values():
        for artifact in artifact_list:
            artifact.mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE

    [execution] = metadata_handler.store.get_executions_by_id([execution_id])
    execution.last_known_state = metadata_store_pb2.Execution.COMPLETE

    execution_lib.put_execution(metadata_handler,
                                execution,
                                contexts,
                                output_artifacts=output_artifacts)
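An illustrative contrast with the variant in Example #13: this version updates artifacts in place and rejects partial channel updates outright. Names (`model_a`, `model_b`) are hypothetical.

executor_output = execution_result_pb2.ExecutorOutput()
executor_output.output_artifacts['model'].artifacts.add()  # only one update
try:
  # Two system-generated outputs but only one update: length mismatch.
  publish_succeeded_execution(
      m, execution.id, contexts, {'model': [model_a, model_b]},
      executor_output)
except RuntimeError as e:
  logging.info('Rejected: %s', e)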
Example #16
    def testGetArtifactsDict(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Create and shuffle a few artifacts. The shuffled order should be
            # retained in the output of `execution_lib.get_artifacts_dict`.
            input_artifact_keys = ('input1', 'input2', 'input3')
            input_artifacts_dict = collections.OrderedDict()
            for input_key in input_artifact_keys:
                input_examples = []
                for i in range(10):
                    input_example = standard_artifacts.Examples()
                    input_example.uri = f'{input_key}/example{i}'
                    input_example.type_id = common_utils.register_type_if_not_exist(
                        m, input_example.artifact_type).id
                    input_examples.append(input_example)
                random.shuffle(input_examples)
                input_artifacts_dict[input_key] = input_examples

            output_models = []
            for i in range(8):
                output_model = standard_artifacts.Model()
                output_model.uri = f'model{i}'
                output_model.type_id = common_utils.register_type_if_not_exist(
                    m, output_model.artifact_type).id
                output_models.append(output_model)
            random.shuffle(output_models)
            output_artifacts_dict = {'model': output_models}

            # Store input artifacts only. Outputs will be saved in put_execution().
            input_mlmd_artifacts = [
                a.mlmd_artifact
                for a in itertools.chain(*input_artifacts_dict.values())
            ]
            artifact_ids = m.store.put_artifacts(input_mlmd_artifacts)
            for artifact_id, mlmd_artifact in zip(artifact_ids,
                                                  input_mlmd_artifacts):
                mlmd_artifact.id = artifact_id

            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                state=metadata_store_pb2.Execution.RUNNING)
            contexts = self._generate_contexts(m)

            # Change the order of the OrderedDict to shuffle the order of input keys.
            input_artifacts_dict.move_to_end('input1')
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts=input_artifacts_dict,
                output_artifacts=output_artifacts_dict)

            # Verify that the same artifacts are returned in the correct order.
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, [metadata_store_pb2.Event.INPUT])
            self.assertEqual(set(input_artifact_keys),
                             set(artifacts_dict.keys()))
            for key in artifacts_dict:
                self.assertEqual([ex.uri for ex in input_artifacts_dict[key]],
                                 [a.uri for a in artifacts_dict[key]],
                                 f'for key={key}')
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, [metadata_store_pb2.Event.OUTPUT])
            self.assertEqual({'model'}, set(artifacts_dict.keys()))
            self.assertEqual([model.uri for model in output_models],
                             [a.uri for a in artifacts_dict['model']])