Пример #1
0
def _build_proto_artifact_dict(artifact_dict):
    """Build PythonExecutorExecutionInfo input/output artifact dicts."""
    proto_dict = {}
    for k, v in artifact_dict.items():
        artifact_list = metadata_store_service_pb2.ArtifactStructList()
        for artifact in v:
            artifact_struct = metadata_store_service_pb2.ArtifactStruct(
                artifact=metadata_store_service_pb2.ArtifactAndType(
                    artifact=artifact.mlmd_artifact,
                    type=artifact.artifact_type))
            artifact_list.elements.append(artifact_struct)
        proto_dict[k] = artifact_list
    return proto_dict
Пример #2
0
def _build_proto_artifact_dict(
    artifact_dict: Mapping[str, Iterable[types.Artifact]]
) -> Dict[str, metadata_store_service_pb2.ArtifactStructList]:
  """Converts a key -> artifacts mapping into ArtifactStructList protos.

  The result populates the input/output artifact maps of
  PythonExecutorExecutionInfo.

  Args:
    artifact_dict: Mapping from key to an iterable of artifacts. Each
      artifact's `mlmd_artifact` and `artifact_type` are wrapped in an
      ArtifactAndType. A falsy value (e.g. None or {}) yields an empty dict.

  Returns:
    A new dict with the same keys. Each key maps to an ArtifactStructList
    whose elements follow the iteration order of that key's artifacts.
  """
  if not artifact_dict:
    return {}

  def _to_struct_list(artifacts):
    # Wrap every artifact in an ArtifactStruct and collect them in order.
    struct_list = metadata_store_service_pb2.ArtifactStructList()
    for item in artifacts:
      and_type = metadata_store_service_pb2.ArtifactAndType(
          artifact=item.mlmd_artifact, type=item.artifact_type)
      struct_list.elements.append(
          metadata_store_service_pb2.ArtifactStruct(artifact=and_type))
    return struct_list

  return {
      key: _to_struct_list(artifacts)
      for key, artifacts in artifact_dict.items()
  }
Пример #3
0
    def get_qualified_artifacts(
        self,
        contexts: List[metadata_store_pb2.Context],
        type_name: Text,
        producer_component_id: Optional[Text] = None,
        output_key: Optional[Text] = None,
    ) -> List[metadata_store_service_pb2.ArtifactAndType]:
        """Gets qualified artifacts that have the right producer info.

        An artifact qualifies when all of the following hold:
        - It is referenced by an OUTPUT event of an execution that is
          associated with every context in `contexts`.
        - That event matches `output_key`, if given.
        - That execution matches `producer_component_id`, if given.
        - The artifact has type `type_name`.
        - The artifact is in the PUBLISHED state.

        Args:
          contexts: context constraints to filter artifacts. An execution
            must belong to all of these contexts to qualify. Must be non-empty.
          type_name: type constraint to filter artifacts
          producer_component_id: producer constraint to filter artifacts; when
            falsy, executions from any component are accepted
          output_key: output key constraint to filter artifacts; when falsy,
            any OUTPUT event is accepted

        Returns:
          A list of ArtifactAndType, containing qualified artifacts. Empty if
          the store has no artifact type named `type_name`.

        Raises:
          AssertionError: if `contexts` is empty (checked only after the type
            lookup succeeds, and only when assertions are enabled), or if
            `output_key` is set and an event of a candidate execution does not
            have exactly two path steps.
        """
        def _match_producer_component_id(execution, component_id):
            if component_id:
                return execution.properties[
                    _EXECUTION_TYPE_KEY_COMPONENT_ID].string_value == component_id
            else:
                return True

        def _match_output_key(event, key):
            if key:
                assert len(
                    event.path.steps) == 2, 'Event must have two path steps.'
                return (event.type == metadata_store_pb2.Event.OUTPUT
                        and event.path.steps[0].key == key)
            else:
                return event.type == metadata_store_pb2.Event.OUTPUT

        try:
            artifact_type = self.store.get_artifact_type(type_name)
        except tf.errors.NotFoundError:
            return []
        if not artifact_type:
            # A falsy lookup result is treated the same as "type not found".
            return []

        # Gets the executions that are associated with all contexts, i.e. the
        # intersection of each context's executions. (A union would let an
        # execution that belongs to only one of the contexts qualify.)
        assert contexts, 'Must have at least one context.'
        executions_by_id = {}
        common_execution_ids = None
        for context in contexts:
            executions = list(self.store.get_executions_by_context(context.id))
            context_execution_ids = {e.id for e in executions}
            if common_execution_ids is None:
                common_execution_ids = context_execution_ids
            else:
                common_execution_ids &= context_execution_ids
            executions_by_id.update((e.id, e) for e in executions)
        # Still None only if `contexts` was empty and assertions are disabled.
        common_execution_ids = common_execution_ids or set()

        executions_within_context = [
            e for execution_id, e in executions_by_id.items()
            if execution_id in common_execution_ids
        ]

        # Filters the executions to match producer component id.
        qualified_producer_executions = [
            e.id for e in executions_within_context
            if _match_producer_component_id(e, producer_component_id)
        ]
        # Gets the output events that have the matched output key.
        qualified_output_events = [
            ev for ev in self.store.get_events_by_execution_ids(
                qualified_producer_executions)
            if _match_output_key(ev, output_key)
        ]

        # Gets the candidate artifacts from output events (deduplicated ids).
        candidate_artifacts = self.store.get_artifacts_by_id(
            list(set(ev.artifact_id for ev in qualified_output_events)))
        # Filters the artifacts that have the right artifact type and state.
        qualified_artifacts = [
            a for a in candidate_artifacts if a.type_id == artifact_type.id
            and self._get_artifact_state(a) == ArtifactState.PUBLISHED
        ]
        return [
            metadata_store_service_pb2.ArtifactAndType(artifact=a,
                                                       type=artifact_type)
            for a in qualified_artifacts
        ]
Пример #4
0
  def get_qualified_artifacts(
      self,
      context: metadata_store_pb2.Context,
      type_name: Text,
      producer_component_id: Optional[Text] = None,
      output_key: Optional[Text] = None,
  ) -> List[metadata_store_service_pb2.ArtifactAndType]:
    """Returns published artifacts of a given type produced within a context.

    An artifact qualifies when all of the following hold:
    - It is referenced by an OUTPUT event of an execution in `context`.
    - That event and execution optionally match `output_key` and
      `producer_component_id`.
    - The artifact has type `type_name`.
    - The artifact is in the PUBLISHED state.

    Args:
      context: context constraint to filter artifacts
      type_name: type constraint to filter artifacts
      producer_component_id: producer constraint to filter artifacts; when
        falsy, executions from any component are accepted
      output_key: output key constraint to filter artifacts; when falsy, any
        OUTPUT event is accepted

    Returns:
      A list of ArtifactAndType, containing qualified artifacts. Empty when the
      store has no artifact type named `type_name`.
    """
    try:
      artifact_type = self.store.get_artifact_type(type_name)
    except tf.errors.NotFoundError:
      return []
    if not artifact_type:
      # A falsy lookup result is treated the same as "type not found".
      return []

    def _from_wanted_producer(execution):
      if not producer_component_id:
        return True
      component_id = execution.properties['component_id'].string_value
      return component_id == producer_component_id

    def _is_wanted_output(event):
      if not output_key:
        return event.type == metadata_store_pb2.Event.OUTPUT
      # When filtering by key, every event (not only OUTPUT ones) must have a
      # two-step path; the first step holds the output key.
      assert len(event.path.steps) == 2, 'Event must have two path steps.'
      return (event.type == metadata_store_pb2.Event.OUTPUT and
              event.path.steps[0].key == output_key)

    producer_execution_ids = []
    for execution in self.store.get_executions_by_context(context.id):
      if _from_wanted_producer(execution):
        producer_execution_ids.append(execution.id)

    events = self.store.get_events_by_execution_ids(producer_execution_ids)
    output_artifact_ids = set(
        event.artifact_id for event in events if _is_wanted_output(event))

    candidates = self.store.get_artifacts_by_id(list(output_artifact_ids))
    published = [
        artifact for artifact in candidates
        if artifact.type_id == artifact_type.id and
        self._get_artifact_state(artifact) == ArtifactState.PUBLISHED
    ]
    return [
        metadata_store_service_pb2.ArtifactAndType(
            artifact=artifact, type=artifact_type)
        for artifact in published
    ]