Example #1
    def testGetArtifactsDict(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Create and shuffle a few artifacts. The shuffled order should be
            # retained in the output of `execution_lib.get_artifacts_dict`.
            input_examples = []
            for i in range(10):
                input_example = standard_artifacts.Examples()
                input_example.uri = 'example{}'.format(i)
                input_example.type_id = common_utils.register_type_if_not_exist(
                    m, input_example.artifact_type).id
                input_examples.append(input_example)
            random.shuffle(input_examples)
            output_models = []
            for i in range(8):
                output_model = standard_artifacts.Model()
                output_model.uri = 'model{}'.format(i)
                output_model.type_id = common_utils.register_type_if_not_exist(
                    m, output_model.artifact_type).id
                output_models.append(output_model)
            random.shuffle(output_models)
            m.store.put_artifacts([
                a.mlmd_artifact
                for a in itertools.chain(input_examples, output_models)
            ])
            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                state=metadata_store_pb2.Execution.RUNNING)
            contexts = self._generate_contexts(m)
            input_artifacts_dict = {'examples': input_examples}
            output_artifacts_dict = {'model': output_models}
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts=input_artifacts_dict,
                output_artifacts=output_artifacts_dict)

            # Verify that the same artifacts are returned in the correct order.
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, metadata_store_pb2.Event.INPUT)
            self.assertCountEqual(['examples'], list(artifacts_dict.keys()))
            self.assertEqual([ex.uri for ex in input_examples],
                             [a.uri for a in artifacts_dict['examples']])
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, metadata_store_pb2.Event.OUTPUT)
            self.assertCountEqual(['model'], list(artifacts_dict.keys()))
            self.assertEqual([model.uri for model in output_models],
                             [a.uri for a in artifacts_dict['model']])
Example #2
def get_all_node_artifacts(
    pipeline: pipeline_pb2.Pipeline, mlmd_handle: metadata.Metadata
) -> Dict[str, Dict[int, Dict[str, List[metadata_store_pb2.Artifact]]]]:
    """Returns all output artifacts of all pipeline nodes if present.

  Args:
    pipeline: Pipeline proto associated with a `PipelineState` object.
    mlmd_handle: Handle to MLMD db.

  Returns:
    Dict of node id to Dict of execution id to Dict of key to output artifact
    list.
  """
    executions_dict = get_all_node_executions(pipeline, mlmd_handle)
    result = {}
    for node_id, executions in executions_dict.items():
        node_artifacts = {}
        for execution in executions:
            execution_artifacts = {}
            for key, artifacts in execution_lib.get_artifacts_dict(
                    mlmd_handle, execution.id, [
                        metadata_store_pb2.Event.OUTPUT,
                        metadata_store_pb2.Event.DECLARED_OUTPUT
                    ]).items():
                execution_artifacts[key] = [
                    artifact.mlmd_artifact for artifact in artifacts
                ]
            node_artifacts[execution.id] = execution_artifacts
        result[node_id] = node_artifacts
    return result
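
The triple-nested return value is easiest to read with a traversal sketch. The following is a minimal, hypothetical usage example, not part of the library; it assumes a `pipeline` proto and an open MLMD handle like the ones passed to the function above:

# Hypothetical usage sketch for get_all_node_artifacts(); `pipeline` and
# `mlmd_handle` are assumed to exist as in the function above.
all_artifacts = get_all_node_artifacts(pipeline, mlmd_handle)
for node_id, executions in all_artifacts.items():
    for execution_id, outputs in executions.items():
        for output_key, artifacts in outputs.items():
            for artifact in artifacts:
                print(node_id, execution_id, output_key, artifact.uri)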
Example #3
def _generate_task_from_execution(
        metadata_handler: metadata.Metadata, pipeline: pipeline_pb2.Pipeline,
        node: pipeline_pb2.PipelineNode,
        execution: metadata_store_pb2.Execution) -> task_lib.Task:
    """Generates `ExecNodeTask` given execution."""
    contexts = metadata_handler.store.get_contexts_by_execution(execution.id)
    exec_properties = _extract_properties(execution)
    input_artifacts = execution_lib.get_artifacts_dict(
        metadata_handler, execution.id, metadata_store_pb2.Event.INPUT)
    outputs_resolver = outputs_utils.OutputsResolver(node,
                                                     pipeline.pipeline_info,
                                                     pipeline.runtime_spec,
                                                     pipeline.execution_mode)
    return task_lib.ExecNodeTask(
        node_uid=task_lib.NodeUid.from_pipeline_node(pipeline, node),
        execution=execution,
        contexts=contexts,
        exec_properties=exec_properties,
        input_artifacts=input_artifacts,
        output_artifacts=outputs_resolver.generate_output_artifacts(
            execution.id),
        executor_output_uri=outputs_resolver.get_executor_output_uri(
            execution.id),
        stateful_working_dir=outputs_resolver.get_stateful_working_directory(
            execution.id))
Example #4
    def _cache_and_publish(self,
                           existing_execution: metadata_store_pb2.Execution):
        """Updates MLMD."""
        cached_execution_contexts = self._get_cached_execution_contexts(
            existing_execution)
        # Check if there are any previous attempts to cache and publish.
        prev_cache_executions = (
            execution_lib.get_executions_associated_with_all_contexts(
                self._mlmd, contexts=cached_execution_contexts))
        if not prev_cache_executions:
            new_execution = execution_publish_utils.register_execution(
                self._mlmd,
                execution_type=metadata_store_pb2.ExecutionType(
                    id=existing_execution.type_id),
                contexts=cached_execution_contexts)
        else:
            if len(prev_cache_executions) > 1:
                logging.warning(
                    'More than one previous cache execution seen when attempting '
                    'reuse_node_outputs: %s', prev_cache_executions)

            if (prev_cache_executions[-1].last_known_state ==
                    metadata_store_pb2.Execution.CACHED):
                return
            else:
                new_execution = prev_cache_executions[-1]

        output_artifacts = execution_lib.get_artifacts_dict(
            self._mlmd,
            existing_execution.id,
            event_types=list(event_lib.VALID_OUTPUT_EVENT_TYPES))

        execution_publish_utils.publish_cached_execution(
            self._mlmd,
            contexts=cached_execution_contexts,
            execution_id=new_execution.id,
            output_artifacts=output_artifacts)
Example #5
    def testGetArtifactsDict(self):
        with metadata.Metadata(connection_config=self._connection_config) as m:
            # Create and shuffle a few artifacts. The shuffled order should be
            # retained in the output of `execution_lib.get_artifacts_dict`.
            input_artifact_keys = ('input1', 'input2', 'input3')
            input_artifacts_dict = collections.OrderedDict()
            for input_key in input_artifact_keys:
                input_examples = []
                for i in range(10):
                    input_example = standard_artifacts.Examples()
                    input_example.uri = f'{input_key}/example{i}'
                    input_example.type_id = common_utils.register_type_if_not_exist(
                        m, input_example.artifact_type).id
                    input_examples.append(input_example)
                random.shuffle(input_examples)
                input_artifacts_dict[input_key] = input_examples

            output_models = []
            for i in range(8):
                output_model = standard_artifacts.Model()
                output_model.uri = f'model{i}'
                output_model.type_id = common_utils.register_type_if_not_exist(
                    m, output_model.artifact_type).id
                output_models.append(output_model)
            random.shuffle(output_models)
            output_artifacts_dict = {'model': output_models}

            # Store input artifacts only. Outputs will be saved in put_execution().
            input_mlmd_artifacts = [
                a.mlmd_artifact
                for a in itertools.chain(*input_artifacts_dict.values())
            ]
            artifact_ids = m.store.put_artifacts(input_mlmd_artifacts)
            for artifact_id, mlmd_artifact in zip(artifact_ids,
                                                  input_mlmd_artifacts):
                mlmd_artifact.id = artifact_id

            execution = execution_lib.prepare_execution(
                m,
                metadata_store_pb2.ExecutionType(name='my_execution_type'),
                state=metadata_store_pb2.Execution.RUNNING)
            contexts = self._generate_contexts(m)

            # Change the order of the OrderedDict to shuffle the order of input keys.
            input_artifacts_dict.move_to_end('input1')
            execution = execution_lib.put_execution(
                m,
                execution,
                contexts,
                input_artifacts=input_artifacts_dict,
                output_artifacts=output_artifacts_dict)

            # Verify that the same artifacts are returned in the correct order.
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, [metadata_store_pb2.Event.INPUT])
            self.assertEqual(set(input_artifact_keys),
                             set(artifacts_dict.keys()))
            for key in artifacts_dict:
                self.assertEqual([ex.uri for ex in input_artifacts_dict[key]],
                                 [a.uri for a in artifacts_dict[key]],
                                 f'for key={key}')
            artifacts_dict = execution_lib.get_artifacts_dict(
                m, execution.id, [metadata_store_pb2.Event.OUTPUT])
            self.assertEqual({'model'}, set(artifacts_dict.keys()))
            self.assertEqual([model.uri for model in output_models],
                             [a.uri for a in artifacts_dict['model']])
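
Note that `execution_lib.get_artifacts_dict` appears with two call forms across these examples: a single event type (Examples #1 and #3) and a list of event types (Examples #2, #4, and #5), which suggests the third parameter became a list of event types in later revisions. A minimal sketch of both forms, assuming an open handle `m` and an `execution` as in the tests above:

# Single event type, as in Examples #1 and #3.
inputs = execution_lib.get_artifacts_dict(
    m, execution.id, metadata_store_pb2.Event.INPUT)

# List of event types, as in Examples #2, #4, and #5.
outputs = execution_lib.get_artifacts_dict(
    m, execution.id, [metadata_store_pb2.Event.OUTPUT])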