Beispiel #1
0
def get_executions(
        metadata_handler: metadata.Metadata,
        node: pipeline_pb2.PipelineNode) -> List[metadata_store_pb2.Execution]:
    """Looks up every MLMD execution that belongs to `node`.

    An execution belongs to the node when it is associated with each of the
    contexts listed in the node's `contexts.contexts` spec.

    Args:
      metadata_handler: Handler used to query the MLMD store.
      node: Pipeline node whose executions are wanted.

    Returns:
      The executions associated with all of the node's contexts. An empty
      list is returned immediately if any of those contexts cannot be found
      in MLMD.
    """
    resolved_contexts = []
    for spec in node.contexts.contexts:
        found = metadata_handler.store.get_context_by_type_and_name(
            spec.type.name, common_utils.get_value(spec.name))
        if found is None:
            # A context that is not registered cannot have any execution
            # attached to it, so no execution can match all node contexts.
            return []
        resolved_contexts.append(found)
    return execution_lib.get_executions_associated_with_all_contexts(
        metadata_handler, resolved_contexts)
Beispiel #2
0
def get_qualified_artifacts(
    metadata_handler: metadata.Metadata,
    contexts: Iterable[metadata_store_pb2.Context],
    artifact_type: metadata_store_pb2.ArtifactType,
    output_key: Optional[str] = None,
) -> List[types.Artifact]:
    """Gets qualified artifacts that have the right producer info.

    An artifact qualifies when all of the following hold:
      * an output event matching `output_key` links it to a successful
        execution associated with every context in `contexts`;
      * its type id equals the id of the MLMD type named `artifact_type.name`;
      * its state is LIVE.

    Args:
      metadata_handler: A metadata handler to access MLMD store.
      contexts: Context constraints to filter artifacts. Any iterable is
        accepted; it is materialized once so emptiness can be checked.
      artifact_type: Type constraint to filter artifacts. Only its `name` is
        used; the registered type is re-read from MLMD.
      output_key: Output key constraint to filter artifacts, passed to
        `event_lib.validate_output_event`.

    Returns:
      A list of qualified TFX Artifacts. The list is empty if the artifact
      type is not registered in MLMD or nothing qualifies.

    Raises:
      AssertionError: If `contexts` is empty.
    """
    # Materialize first: a generator object is always truthy, so testing the
    # raw argument would let an empty generator through.
    contexts = list(contexts)
    # We expect to have at least one context for input resolution. Raise
    # explicitly (same exception type as a failed `assert`) so the check is
    # not stripped when Python runs with -O.
    if not contexts:
        raise AssertionError('Must have at least one context.')

    artifact_type_name = artifact_type.name
    try:
        mlmd_artifact_type = metadata_handler.store.get_artifact_type(
            artifact_type_name)
    except mlmd.errors.NotFoundError:
        logging.warning('Artifact type %s is not found in MLMD.',
                        artifact_type_name)
        mlmd_artifact_type = None

    if not mlmd_artifact_type:
        return []

    executions_within_context = (
        execution_lib.get_executions_associated_with_all_contexts(
            metadata_handler, contexts))

    # Filters out non-success executions.
    qualified_producer_executions = [
        e.id for e in executions_within_context
        if execution_lib.is_execution_successful(e)
    ]
    if not qualified_producer_executions:
        # No successful producer means no candidate artifact; skip querying
        # MLMD with an empty id list.
        return []

    # Gets the output events that have the matched output key.
    qualified_output_events = [
        ev for ev in metadata_handler.store.get_events_by_execution_ids(
            qualified_producer_executions)
        if event_lib.validate_output_event(ev, output_key)
    ]
    if not qualified_output_events:
        return []

    # Gets the candidate artifacts from output events. Ids are de-duplicated
    # and sorted so the request is deterministic.
    candidate_artifacts = metadata_handler.store.get_artifacts_by_id(
        sorted({ev.artifact_id for ev in qualified_output_events}))
    # Filters the artifacts that have the right artifact type and state.
    qualified_artifacts = [
        a for a in candidate_artifacts
        if a.type_id == mlmd_artifact_type.id
        and a.state == metadata_store_pb2.Artifact.LIVE
    ]
    return [
        artifact_utils.deserialize_artifact(mlmd_artifact_type, a)
        for a in qualified_artifacts
    ]
Beispiel #3
0
    def _get_successful_executions(
            self, node_id: str,
            run_id: str) -> List[metadata_store_pb2.Execution]:
        """Returns the successful Executions of one node within one pipeline run.

        Candidate executions are those associated with all three contexts:
        the node context for `node_id`, the pipeline-run context for
        `run_id`, and the pipeline context.

        Args:
          node_id: Id of the node whose Executions are queried.
          run_id: Pipeline run id whose Executions are queried.

        Returns:
          The successful executions, ordered by
          `execution_lib.sort_executions_newest_to_oldest`.

        Raises:
          LookupError: If none of the candidate Executions is successful.
        """
        required_contexts = [
            self._get_node_context(node_id),
            self._get_pipeline_run_context(run_id),
            self._pipeline_context,
        ]
        candidates = execution_lib.get_executions_associated_with_all_contexts(
            self._mlmd, contexts=required_contexts)
        successful = list(
            filter(execution_lib.is_execution_successful, candidates))
        if successful:
            return execution_lib.sort_executions_newest_to_oldest(successful)
        raise LookupError(
            f'No previous successful executions found for node_id {node_id} in '
            f'pipeline_run {run_id}')
Beispiel #4
0
    def testGetExecutionsAssociatedWithAllContexts(self):
        """Only executions attached to every queried context are returned."""
        with metadata.Metadata(connection_config=self._connection_config) as m:
            contexts = self._generate_contexts(m)
            self.assertLen(contexts, 2)

            def register(state, associated_contexts):
                # Prepares an execution of a shared type in `state` and stores
                # it in MLMD, associated with `associated_contexts`.
                prepared = execution_lib.prepare_execution(
                    m, metadata_store_pb2.ExecutionType(name='my_execution_type'),
                    state)
                return execution_lib.put_execution(m, prepared,
                                                   associated_contexts)

            # Two executions tied to a single context each...
            only_first = register(metadata_store_pb2.Execution.RUNNING,
                                  [contexts[0]])
            only_second = register(metadata_store_pb2.Execution.COMPLETE,
                                   [contexts[1]])
            # ...and one tied to both contexts.
            in_both = register(metadata_store_pb2.Execution.NEW, contexts)

            # (subTest label, contexts to query, expected execution ids)
            cases = (
                ((0,), [contexts[0]], [only_first.id, in_both.id]),
                ((1,), [contexts[1]], [only_second.id, in_both.id]),
                ((0, 1), contexts, [in_both.id]),
            )
            for indices, query_contexts, expected_ids in cases:
                with self.subTest(for_contexts=indices):
                    found = execution_lib.get_executions_associated_with_all_contexts(
                        m, query_contexts)
                    self.assertCountEqual(expected_ids, [e.id for e in found])
Beispiel #5
0
    def _cache_and_publish(self,
                           existing_execution: metadata_store_pb2.Execution):
        """Publishes the outputs of `existing_execution` as a cached execution.

        The target execution lives under the contexts returned by
        `self._get_cached_execution_contexts(existing_execution)`. If an
        earlier attempt exists there, the last one returned is reused; if it
        is already in the CACHED state, nothing is done. Otherwise a new
        execution of the same type is registered. The output artifacts of
        `existing_execution` are then published against the target execution.

        Args:
          existing_execution: The execution whose outputs are being reused.
        """
        contexts = self._get_cached_execution_contexts(existing_execution)
        # Look for earlier attempts to cache and publish under these contexts.
        previous = execution_lib.get_executions_associated_with_all_contexts(
            self._mlmd, contexts=contexts)

        if previous:
            if len(previous) > 1:
                logging.warning(
                    'More than one previous cache executions seen when attempting '
                    'reuse_node_outputs: %s', previous)
            # NOTE(review): treats the last element as the most recent attempt;
            # the ordering of get_executions_associated_with_all_contexts is
            # not visible here -- confirm it is chronological.
            latest = previous[-1]
            if latest.last_known_state == metadata_store_pb2.Execution.CACHED:
                return
            target_execution = latest
        else:
            target_execution = execution_publish_utils.register_execution(
                self._mlmd,
                execution_type=metadata_store_pb2.ExecutionType(
                    id=existing_execution.type_id),
                contexts=contexts)

        output_artifacts = execution_lib.get_artifacts_dict(
            self._mlmd,
            existing_execution.id,
            event_types=list(event_lib.VALID_OUTPUT_EVENT_TYPES))

        execution_publish_utils.publish_cached_execution(
            self._mlmd,
            contexts=contexts,
            execution_id=target_execution.id,
            output_artifacts=output_artifacts)