def _get_one_hop_executions(
    store: mlmd.MetadataStore,
    artifact_ids: Iterable[int],
    direction: _Direction,
    filter_type: Optional[metadata_store_pb2.ExecutionType] = None
) -> List[metadata_store_pb2.Execution]:
  """Gets a list of executions within 1-hop neighborhood of the `artifact_ids`.

  Args:
    store: A ml-metadata MetadataStore to look for neighborhood executions.
    artifact_ids: The artifacts' ids in the `store`.
    direction: A direction to specify whether returning ancestors or
      successors.
    filter_type: An optional type filter of the returned executions; if
      given, only executions of that type are returned.

  Returns:
    A list of qualified executions within 1-hop neighborhood in the `store`.

  Raises:
    ValueError: If `direction` is neither `_Direction.ANCESTOR` nor
      `_Direction.SUCCESSOR`.
  """
  if direction == _Direction.ANCESTOR:
    # Ancestors: executions linked to the artifacts through OUTPUT events.
    traverse_event = metadata_store_pb2.Event.OUTPUT
  elif direction == _Direction.SUCCESSOR:
    # Successors: executions linked to the artifacts through INPUT events.
    traverse_event = metadata_store_pb2.Event.INPUT
  else:
    # Without this branch `traverse_event` would be unbound below.
    raise ValueError(f'Unsupported traversal direction: {direction}.')
  # NOTE(review): unlike _get_one_hop_artifacts, DECLARED_OUTPUT /
  # DECLARED_INPUT events are not traversed here; confirm this is intended.
  execution_ids = {
      event.execution_id
      for event in store.get_events_by_artifact_ids(artifact_ids)
      if event.type == traverse_event
  }
  return [
      execution for execution in store.get_executions_by_id(execution_ids)
      if filter_type is None or execution.type_id == filter_type.id
  ]
def _get_input_examples_artifacts(
    self, store: mlmd.MetadataStore,
    execution_type: Text) -> List[metadata_store_pb2.Artifact]:
  """Gets the `examples` artifacts linked to the latest execution of a type.

  Args:
    store: A ml-metadata MetadataStore to read executions, events and
      artifacts from.
    execution_type: The ExecutionType name whose latest execution (the one
      with the largest id) is inspected.

  Returns:
    The artifacts whose events with that execution have a path step keyed
    `examples`.

  Raises:
    ValueError: If `store` holds no execution of `execution_type`.
  """
  executions = store.get_executions_by_type(execution_type)
  if not executions:
    # max() would otherwise fail with an unhelpful "empty sequence" message.
    raise ValueError(
        f'No executions of type {execution_type!r} found in the given store.')
  # The execution with the largest id is treated as the latest one.
  latest_execution = max(executions, key=lambda execution: execution.id)
  events = store.get_events_by_execution_ids([latest_execution.id])
  # NOTE(review): events of every type are considered here, not only input
  # events; confirm whether the "input" in the name should filter on type.
  artifact_ids = [
      event.artifact_id
      for event in events
      if any(step.key == 'examples' for step in event.path.steps)
  ]
  return store.get_artifacts_by_id(artifact_ids)
def _validate_model_id(store: mlmd.MetadataStore,
                       model_type: metadata_store_pb2.ArtifactType,
                       model_id: int) -> metadata_store_pb2.Artifact:
  """Resolves `model_id` to a Model artifact held in `store`.

  Args:
    store: The ml-metadata MetadataStore to look the artifact up in.
    model_type: The Model ArtifactType the resolved artifact must have.
    model_id: Id of the artifact expected to be a Model.

  Returns:
    The artifact stored under `model_id`.

  Raises:
    ValueError: If no artifact has id `model_id`, or the artifact found is
      not of type `model_type`.
  """
  candidates = store.get_artifacts_by_id([model_id])
  if not candidates:
    raise ValueError(f'Input model_id cannot be found: {model_id}.')
  first_match = candidates[0]
  if first_match.type_id == model_type.id:
    return first_match
  raise ValueError(
      f'Found artifact with `model_id` is not an instance of Model: {first_match}.'
  )
def _get_one_hop_artifacts(
    store: mlmd.MetadataStore,
    artifact_ids: Iterable[int],
    direction: _Direction,
    filter_type: Optional[metadata_store_pb2.ArtifactType] = None
) -> List[metadata_store_pb2.Artifact]:
  """Gets a list of artifacts within 1-hop neighborhood of the `artifact_ids`.

  The hop goes artifact -> execution -> artifact: first the executions linked
  to `artifact_ids` by events of the requested direction are collected, then
  the artifacts linked to those executions by events on the other side.

  Args:
    store: A ml-metadata MetadataStore to look for neighborhood artifacts.
    artifact_ids: The artifacts' ids in the `store`.
    direction: A direction to specify whether returning ancestors or
      successors.
    filter_type: An optional type filter of the returned artifacts; if given,
      only artifacts of that type are returned.

  Returns:
    A list of qualified artifacts within 1-hop neighborhood in the `store`.

  Raises:
    ValueError: If `direction` is neither `_Direction.ANCESTOR` nor
      `_Direction.SUCCESSOR`.
  """
  if direction == _Direction.ANCESTOR:
    # Executions that output the given artifacts, then their input artifacts.
    execution_event_types = (metadata_store_pb2.Event.OUTPUT,
                             metadata_store_pb2.Event.DECLARED_OUTPUT)
    artifact_event_types = (metadata_store_pb2.Event.INPUT,
                            metadata_store_pb2.Event.DECLARED_INPUT)
  elif direction == _Direction.SUCCESSOR:
    # Executions that take the given artifacts as input, then their outputs.
    execution_event_types = (metadata_store_pb2.Event.INPUT,
                             metadata_store_pb2.Event.DECLARED_INPUT)
    artifact_event_types = (metadata_store_pb2.Event.OUTPUT,
                            metadata_store_pb2.Event.DECLARED_OUTPUT)
  else:
    raise ValueError(f'Unsupported traversal direction: {direction}.')
  execution_ids = {
      event.execution_id
      for event in store.get_events_by_artifact_ids(artifact_ids)
      if event.type in execution_event_types
  }
  neighbor_artifact_ids = {
      event.artifact_id
      for event in store.get_events_by_execution_ids(execution_ids)
      if event.type in artifact_event_types
  }
  return [
      artifact
      for artifact in store.get_artifacts_by_id(neighbor_artifact_ids)
      if filter_type is None or artifact.type_id == filter_type.id
  ]
def _get_tfx_pipeline_types(store: mlmd.MetadataStore) -> _PipelineTypes:
  """Looks up the TFX artifact and execution types registered in `store`.

  Args:
    store: The ml-metadata MetadataStore whose registered types are read.

  Returns:
    An instance of _PipelineTypes holding the dataset, statistics, model and
    metrics ArtifactTypes plus the trainer ExecutionType found in `store`.

  Raises:
    ValueError: If any required ArtifactType or ExecutionType is not
      registered in `store`, i.e. it is not considered a valid TFX store.
  """
  artifacts_by_name = {}
  for artifact_type in store.get_artifact_types():
    artifacts_by_name[artifact_type.name] = artifact_type
  absent_artifacts = {
      _TFX_DATASET_TYPE, _TFX_STATS_TYPE, _TFX_MODEL_TYPE, _TFX_METRICS_TYPE
  }.difference(artifacts_by_name)
  if absent_artifacts:
    raise ValueError(
        f'Given `store` is invalid: missing ArtifactTypes: {absent_artifacts}.'
    )

  # Execution types are only fetched after the artifact-type check passes.
  executions_by_name = {}
  for execution_type in store.get_execution_types():
    executions_by_name[execution_type.name] = execution_type
  absent_executions = {_TFX_TRAINER_TYPE}.difference(executions_by_name)
  if absent_executions:
    raise ValueError(
        f'Given `store` is invalid: missing ExecutionTypes: {absent_executions}.'
    )

  return _PipelineTypes(
      dataset_type=artifacts_by_name[_TFX_DATASET_TYPE],
      stats_type=artifacts_by_name[_TFX_STATS_TYPE],
      model_type=artifacts_by_name[_TFX_MODEL_TYPE],
      metrics_type=artifacts_by_name[_TFX_METRICS_TYPE],
      trainer_type=executions_by_name[_TFX_TRAINER_TYPE])