Exemplo n.º 1
0
def _get_one_hop_executions(
    store: mlmd.MetadataStore,
    artifact_ids: Iterable[int],
    direction: _Direction,
    filter_type: Optional[metadata_store_pb2.ExecutionType] = None
) -> List[metadata_store_pb2.Execution]:
    """Gets a list of executions within 1-hop neighborhood of the `artifact_ids`.

    An execution qualifies when an event of the type selected by `direction`
    links it to one of the given artifacts: OUTPUT events when looking for
    ancestors, INPUT events when looking for successors.

    Args:
        store: A ml-metadata MetadataStore to look for neighborhood executions.
        artifact_ids: The artifacts' ids in the `store`.
        direction: A direction to specify whether returning ancestors or
            successors.
        filter_type: An optional type filter of the returned executions; if
            given, only executions of that type are returned.

    Returns:
        A list of qualified executions within 1-hop neighborhood in the `store`.

    Raises:
        ValueError: If `direction` is neither `_Direction.ANCESTOR` nor
            `_Direction.SUCCESSOR`.
    """
    if direction == _Direction.ANCESTOR:
        traverse_event = metadata_store_pb2.Event.OUTPUT
    elif direction == _Direction.SUCCESSOR:
        traverse_event = metadata_store_pb2.Event.INPUT
    else:
        # Without this branch `traverse_event` would be unbound and the
        # comprehension below would fail with an UnboundLocalError.
        raise ValueError(f'Unsupported direction: {direction!r}.')
    # A set collapses execution ids that show up in more than one event.
    execution_ids = {
        event.execution_id
        for event in store.get_events_by_artifact_ids(artifact_ids)
        if event.type == traverse_event
    }
    return [
        execution for execution in store.get_executions_by_id(execution_ids)
        if not filter_type or execution.type_id == filter_type.id
    ]
 def _get_input_examples_artifacts(
         self, store: mlmd.MetadataStore,
         execution_type: Text) -> List[metadata_store_pb2.Artifact]:
     """Gets the 'examples' artifacts attached to the latest execution of a type.

     The latest execution is the one with the highest id among the executions
     of `execution_type`. An artifact is selected when the event linking it to
     that execution has a path step whose key is 'examples'.

     Args:
         store: A ml-metadata MetadataStore to query.
         execution_type: The execution type passed to
             `store.get_executions_by_type`.

     Returns:
         The artifacts returned by `store.get_artifacts_by_id` for the selected
         artifact ids.

     Raises:
         ValueError: If the `store` has no executions of `execution_type`.
     """
     executions = store.get_executions_by_type(execution_type)
     if not executions:
         # max() would raise a ValueError with an uninformative message here.
         raise ValueError(
             f'No executions found for execution type: {execution_type}.')
     # Get latest execution.
     latest_execution = max(executions, key=lambda execution: execution.id)
     events = store.get_events_by_execution_ids([latest_execution.id])
     # any() stops at the first matching step, so every event contributes its
     # artifact id at most once.
     artifact_ids = [
         event.artifact_id for event in events
         if any(step.key == 'examples' for step in event.path.steps)
     ]
     return store.get_artifacts_by_id(artifact_ids)
Exemplo n.º 3
0
def _validate_model_id(store: mlmd.MetadataStore,
                       model_type: metadata_store_pb2.ArtifactType,
                       model_id: int) -> metadata_store_pb2.Artifact:
    """Resolves `model_id` in the `store` and checks it is a Model artifact.

    Args:
        store: A ml-metadata MetadataStore to look the artifact up in.
        model_type: The Model ArtifactType in the `store`.
        model_id: The id for the model artifact in the `store`.

    Returns:
        The artifact with the given id, whose type id matches `model_type`.

    Raises:
        ValueError: If no artifact with `model_id` exists in the `store`, or
            if the artifact found is not of `model_type`.
    """
    lookup_result = store.get_artifacts_by_id([model_id])
    if lookup_result:
        # Only one id was requested, so only the first entry is relevant.
        model = lookup_result[0]
        if model.type_id == model_type.id:
            return model
        raise ValueError(
            f'Found artifact with `model_id` is not an instance of Model: {model}.'
        )
    raise ValueError(f'Input model_id cannot be found: {model_id}.')
Exemplo n.º 4
0
def _get_one_hop_artifacts(
    store: mlmd.MetadataStore,
    artifact_ids: Iterable[int],
    direction: _Direction,
    filter_type: Optional[metadata_store_pb2.ArtifactType] = None
) -> List[metadata_store_pb2.Artifact]:
    """Gets a list of artifacts within 1-hop neighborhood of the `artifact_ids`.

    The traversal takes two steps: from the given artifacts to the executions
    linked to them, then from those executions to their artifacts. For
    ancestors the first step follows (DECLARED_)OUTPUT events and the second
    follows (DECLARED_)INPUT events; for successors the two are swapped.

    Args:
        store: A ml-metadata MetadataStore to look for neighborhood artifacts.
        artifact_ids: The artifacts' ids in the `store`.
        direction: A direction to specify whether returning ancestors or
            successors.
        filter_type: An optional type filter of the returned artifacts; if
            given, only artifacts of that type are returned.

    Returns:
        A list of qualified artifacts within 1-hop neighborhood in the `store`.

    Raises:
        ValueError: If `direction` is neither `_Direction.ANCESTOR` nor
            `_Direction.SUCCESSOR`.
    """
    output_events = (metadata_store_pb2.Event.OUTPUT,
                     metadata_store_pb2.Event.DECLARED_OUTPUT)
    input_events = (metadata_store_pb2.Event.INPUT,
                    metadata_store_pb2.Event.DECLARED_INPUT)
    if direction == _Direction.ANCESTOR:
        to_execution_events, to_artifact_events = output_events, input_events
    elif direction == _Direction.SUCCESSOR:
        to_execution_events, to_artifact_events = input_events, output_events
    else:
        # Fail with a clear message rather than an incidental lookup error.
        raise ValueError(f'Unsupported direction: {direction!r}.')
    # Step 1: executions linked to the given artifacts.
    execution_ids = {
        event.execution_id
        for event in store.get_events_by_artifact_ids(artifact_ids)
        if event.type in to_execution_events
    }
    # Step 2: artifacts on the other side of those executions.
    neighbor_artifact_ids = {
        event.artifact_id
        for event in store.get_events_by_execution_ids(execution_ids)
        if event.type in to_artifact_events
    }
    return [
        artifact
        for artifact in store.get_artifacts_by_id(neighbor_artifact_ids)
        if not filter_type or artifact.type_id == filter_type.id
    ]
Exemplo n.º 5
0
def _get_tfx_pipeline_types(store: mlmd.MetadataStore) -> _PipelineTypes:
    """Collects the TFX artifact and execution types registered in the `store`.

    Args:
        store: A ml-metadata MetadataStore to read ArtifactTypes and
            ExecutionTypes from.

    Returns:
        An instance of _PipelineTypes holding the dataset, stats, model and
        metrics artifact types and the trainer execution type.

    Raises:
        ValueError: If any of the expected artifact types or the trainer
            execution type is not registered in the `store`.
    """
    # Artifact types are checked first; execution types are only fetched once
    # every expected artifact type is present.
    artifacts_by_name = {
        artifact_type.name: artifact_type
        for artifact_type in store.get_artifact_types()
    }
    required_artifact_names = {
        _TFX_DATASET_TYPE,
        _TFX_STATS_TYPE,
        _TFX_MODEL_TYPE,
        _TFX_METRICS_TYPE,
    }
    missing_types = required_artifact_names.difference(
        artifacts_by_name.keys())
    if missing_types:
        raise ValueError(
            f'Given `store` is invalid: missing ArtifactTypes: {missing_types}.'
        )

    executions_by_name = {
        execution_type.name: execution_type
        for execution_type in store.get_execution_types()
    }
    required_execution_names = {_TFX_TRAINER_TYPE}
    missing_types = required_execution_names.difference(
        executions_by_name.keys())
    if missing_types:
        raise ValueError(
            f'Given `store` is invalid: missing ExecutionTypes: {missing_types}.'
        )

    return _PipelineTypes(
        dataset_type=artifacts_by_name[_TFX_DATASET_TYPE],
        stats_type=artifacts_by_name[_TFX_STATS_TYPE],
        model_type=artifacts_by_name[_TFX_MODEL_TYPE],
        metrics_type=artifacts_by_name[_TFX_METRICS_TYPE],
        trainer_type=executions_by_name[_TFX_TRAINER_TYPE],
    )