Example #1
def test_channel_utils_as_channel_success(self):
    instance_a = Artifact('MyTypeName')
    instance_b = Artifact('MyTypeName')
    chnl_original = Channel('MyTypeName',
                            artifacts=[instance_a, instance_b])
    chnl_result = channel_utils.as_channel(chnl_original)
    self.assertEqual(chnl_original, chnl_result)
Example #2
def testUnwrapChannelDict(self):
    instance_a = Artifact('MyTypeName')
    instance_b = Artifact('MyTypeName')
    channel_dict = {
        'id': Channel('MyTypeName', artifacts=[instance_a, instance_b])
    }
    result = channel_utils.unwrap_channel_dict(channel_dict)
    self.assertDictEqual(result, {'id': [instance_a, instance_b]})
Example #3
def deserialize_artifact(
        artifact_type: metadata_store_pb2.ArtifactType,
        artifact: Optional[metadata_store_pb2.Artifact] = None) -> Artifact:
    """Reconstruct Artifact object from MLMD proto descriptors.

  Internal method, no backwards compatibility guarantees.

  Args:
    artifact_type: A metadata_store_pb2.ArtifactType proto object describing
      the type of the artifact.
    artifact: A metadata_store_pb2.Artifact proto object describing the
      contents of the artifact.  If not provided, an Artifact of the desired
      type with empty contents is created.

  Returns:
    Artifact subclass object for the given MLMD proto descriptors.
  """
    # Validate inputs.
    if not isinstance(artifact_type, metadata_store_pb2.ArtifactType):
        raise ValueError((
            'Expected metadata_store_pb2.ArtifactType for artifact_type, got %s '
            'instead') % (artifact_type, ))
    if artifact and not isinstance(artifact, metadata_store_pb2.Artifact):
        raise ValueError(
            ('Expected metadata_store_pb2.Artifact for artifact, got %s '
             'instead') % (artifact, ))

    # Make sure the module containing the standard Artifact subclass
    # definitions is imported. Modules containing custom artifact subclasses
    # that need to be deserialized should be imported by the entrypoint of the
    # application or container.
    from tfx.types import standard_artifacts  # pylint: disable=g-import-not-at-top,unused-variable

    # Attempt to find the appropriate Artifact subclass for reconstructing this
    # object.
    artifact_cls = None
    for cls in Artifact.__subclasses__():
        if cls.TYPE_NAME == artifact_type.name:
            artifact_cls = cls

    # Construct the Artifact object, using a concrete Artifact subclass when
    # possible.
    if artifact_cls:
        result = artifact_cls()
        result.set_mlmd_artifact_type(artifact_type)
    else:
        absl.logging.warning((
            'Could not load artifact class for type %r; using fallback '
            'deserialization for the relevant artifact. If this is not intended, '
            'please make sure that the artifact class for this type can be '
            'imported within your container or environment where a component is '
            'executed to consume this type.') % (artifact_type.name))
        result = Artifact(mlmd_artifact_type=artifact_type)
    if artifact:
        result.set_mlmd_artifact(artifact)
    return result
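A minimal usage sketch for deserialize_artifact; the type name and URI below are illustrative, and 'Examples' is assumed to match standard_artifacts.Examples.TYPE_NAME so that a concrete subclass is found:

# Hypothetical usage sketch (values are illustrative).
from ml_metadata.proto import metadata_store_pb2

artifact_type = metadata_store_pb2.ArtifactType()
artifact_type.name = 'Examples'  # assumed to resolve to standard_artifacts.Examples

mlmd_artifact = metadata_store_pb2.Artifact()
mlmd_artifact.uri = '/tmp/examples'

examples = deserialize_artifact(artifact_type, mlmd_artifact)
print(examples.uri)  # '/tmp/examples'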
Example #4
  def _get_outputs_of_execution(
      self, desired_input_ids: Set[int], execution_id: int,
      events: List[metadata_store_pb2.Event]
  ) -> Optional[Dict[Text, List[Artifact]]]:
    """Fetches outputs produced by a historical execution with desired inputs.

    If the desired input ids are not exactly the same as the input artifacts
    of the given execution id, return nothing. Otherwise, return the output
    artifacts in the format of key -> List[Artifact].

    Args:
      desired_input_ids: artifact ids of desired inputs.
      execution_id: the id of the execution that produced the outputs.
      events: events related to the execution id.

    Returns:
      A dict of key -> List[Artifact] as the result, or None if the execution
      inputs do not exactly match the desired inputs.
    """

    execution_input_ids = set(event.artifact_id
                              for event in events
                              if event.type == metadata_store_pb2.Event.INPUT)
    # Comparing set sizes is sufficient here: we only need to rule out the
    # case where the past execution consumed more inputs than the desired
    # inputs.
    if len(desired_input_ids) != len(execution_input_ids):
      absl.logging.debug('Execution %s does not match all inputs' %
                         execution_id)
      return None

    absl.logging.debug('Execution %s matches all inputs' % execution_id)
    result = collections.defaultdict(list)

    output_events = [
        event for event in events
        if event.type == metadata_store_pb2.Event.OUTPUT
    ]
    output_events.sort(key=lambda e: e.path.steps[1].index)
    cached_output_artifacts = self.store.get_artifacts_by_id(
        [e.artifact_id for e in output_events])
    artifact_types = self.store.get_artifact_types_by_id(
        [a.type_id for a in cached_output_artifacts])

    for event, mlmd_artifact, artifact_type in zip(output_events,
                                                   cached_output_artifacts,
                                                   artifact_types):
      key = event.path.steps[0].key
      tfx_artifact = Artifact(mlmd_artifact_type=artifact_type)
      tfx_artifact.set_mlmd_artifact(mlmd_artifact)
      result[key].append(tfx_artifact)

    return result
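_get_outputs_of_execution is private to its class; still, a hedged caller sketch may help. metadata_handler (an object exposing this method and an MLMD store), the ids, and the cache-reuse handling are all assumptions for illustration:

# Hypothetical caller sketch; metadata_handler and the ids are assumptions.
events = metadata_handler.store.get_events_by_execution_ids([42])
outputs = metadata_handler._get_outputs_of_execution(
    desired_input_ids={1, 2}, execution_id=42, events=events)
if outputs is not None:
    # The past execution consumed exactly the desired inputs, so its
    # outputs can be reused as a cache hit.
    for key, artifact_list in outputs.items():
        print(key, [a.uri for a in artifact_list])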
Example #5
def _get_data_view_info(
        examples: artifact.Artifact) -> Optional[Tuple[str, int]]:
    """Returns the payload format and data view URI and ID from examples."""
    assert examples.type is standard_artifacts.Examples, (
        'examples must be of type standard_artifacts.Examples')
    payload_format = examples_utils.get_payload_format(examples)
    if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PROTO:
        data_view_uri = examples.get_string_custom_property(
            constants.DATA_VIEW_URI_PROPERTY_KEY)
        if data_view_uri:
            data_view_create_time = examples.get_int_custom_property(
                constants.DATA_VIEW_CREATE_TIME_KEY)
            return data_view_uri, data_view_create_time

    return None
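A hedged usage sketch, assuming the same `constants`, `examples_utils`, and `example_gen_pb2` modules as the snippet above, and that `examples_utils.set_payload_format` exists as the setter counterpart of the getter used here:

# Hypothetical usage sketch; property values are illustrative.
examples = standard_artifacts.Examples()
examples_utils.set_payload_format(
    examples, example_gen_pb2.PayloadFormat.FORMAT_PROTO)
examples.set_string_custom_property(
    constants.DATA_VIEW_URI_PROPERTY_KEY, '/tmp/data_view')
examples.set_int_custom_property(constants.DATA_VIEW_CREATE_TIME_KEY, 1)

info = _get_data_view_info(examples)  # ('/tmp/data_view', 1) given the setup above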
Example #6
    def search_artifacts(self, artifact_name: Text, pipeline_name: Text,
                         run_id: Text,
                         producer_component_id: Text) -> List[Artifact]:
        """Search artifacts that matches given info.

    Args:
      artifact_name: the name of the artifact that set by producer component.
        The name is logged both in artifacts and the events when the execution
        being published.
      pipeline_name: the name of the pipeline that produces the artifact
      run_id: the run id of the pipeline run that produces the artifact
      producer_component_id: the id of the component that produces the artifact

    Returns:
      A list of Artifacts that matches the given info

    Raises:
      RuntimeError: when no matching execution is found given producer info.
    """
        producer_execution = None
        matching_artifact_ids = set()
        for execution in self._store.get_executions():
            if (execution.properties['pipeline_name'].string_value
                    == pipeline_name
                    and execution.properties['run_id'].string_value == run_id
                    and execution.properties['component_id'].string_value
                    == producer_component_id):
                producer_execution = execution
        if not producer_execution:
            raise RuntimeError(
                'Cannot find matching execution with pipeline name %s, '
                'run id %s and component id %s' %
                (pipeline_name, run_id, producer_component_id))
        for event in self._store.get_events_by_execution_ids(
            [producer_execution.id]):
            if (event.type == metadata_store_pb2.Event.OUTPUT
                    and event.path.steps[0].key == artifact_name):
                matching_artifact_ids.add(event.artifact_id)

        result_artifacts = []
        for a in self._store.get_artifacts_by_id(list(matching_artifact_ids)):
            tfx_artifact = Artifact(a.properties['type_name'].string_value)
            tfx_artifact.artifact = a
            result_artifacts.append(tfx_artifact)
        return result_artifacts
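A hedged usage sketch; `metadata_handler` is assumed to be an instance of the class defining this method, and the names must match an execution actually recorded in MLMD:

# Hypothetical usage sketch; names are illustrative.
artifacts = metadata_handler.search_artifacts(
    artifact_name='examples',
    pipeline_name='my_pipeline',
    run_id='2021-01-01T00:00:00',
    producer_component_id='CsvExampleGen')
for tfx_artifact in artifacts:
    print(tfx_artifact.uri)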
Example #7
def is_artifact_version_older_than(artifact: Artifact,
                                   artifact_version: Text) -> bool:
  """Check if artifact belongs to old version."""
  if artifact.mlmd_artifact.state == metadata_store_pb2.Artifact.UNKNOWN:
    # Newly generated artifact should use the latest artifact payload format.
    return False

  # For artifacts resolved from MLMD.
  if not artifact.has_custom_property(ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY):
    # An artifact without a recorded version is treated as old.
    return True

  # Otherwise, the artifact is old iff its recorded version predates the
  # given version.
  return (version.parse(
      artifact.get_string_custom_property(
          ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY)) <
          version.parse(artifact_version))
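A hedged sketch of the three branches above, assuming standard_artifacts and metadata_store_pb2 are imported; the LIVE state is used only to simulate an artifact resolved from MLMD:

# Hypothetical usage sketch.
model = standard_artifacts.Model()
assert not is_artifact_version_older_than(model, '0.30.0')  # fresh artifact: state UNKNOWN

model.mlmd_artifact.state = metadata_store_pb2.Artifact.LIVE  # simulate MLMD-resolved artifact
assert is_artifact_version_older_than(model, '0.30.0')  # no version property recorded

model.set_string_custom_property(ARTIFACT_TFX_VERSION_CUSTOM_PROPERTY_KEY, '0.28.0')
assert is_artifact_version_older_than(model, '0.30.0')  # 0.28.0 < 0.30.0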
Example #8
@classmethod
def from_json_dict(cls, dict_data: Dict[Text, Any]) -> Any:
    artifact_type = metadata_store_pb2.ArtifactType()
    json_format.Parse(json.dumps(dict_data['type']), artifact_type)
    type_cls = artifact_utils.get_artifact_type_class(artifact_type)
    artifacts = [Artifact.from_json_dict(a) for a in dict_data['artifacts']]
    producer_component_id = dict_data.get('producer_component_id', None)
    output_key = dict_data.get('output_key', None)
    return Channel(type=type_cls,
                   artifacts=artifacts,
                   producer_component_id=producer_component_id,
                   output_key=output_key)
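A hedged round-trip sketch, assuming Channel also implements the to_json_dict counterpart in the same module and that standard_artifacts is imported:

# Hypothetical round-trip sketch.
original = Channel(type=standard_artifacts.Examples)
restored = Channel.from_json_dict(original.to_json_dict())
assert restored.type_name == original.type_name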
Example #9
def get_tfxio_factory_from_artifact(
    examples: artifact.Artifact,
    telemetry_descriptors: List[Text],
    schema: Optional[schema_pb2.Schema] = None,
    read_as_raw_records: bool = False,
    raw_record_column_name: Optional[Text] = None
) -> Callable[[Text], tfxio.TFXIO]:
    """Returns a factory function that creates a proper TFXIO.

  Args:
    examples: The Examples artifact that the TFXIO is intended to access.
    telemetry_descriptors: A set of descriptors that identify the component
      that is instantiating the TFXIO. These will be used to construct the
      namespace to contain metrics for profiling and are therefore expected to
      be identifiers of the component itself and not individual instances of
      source use.
    schema: TFMD schema. Note that without a schema, some TFXIO interfaces
      in certain TFXIO implementations might not be available.
    read_as_raw_records: If True, ignore the payload type of `examples`. Always
      use RawTfRecord TFXIO.
    raw_record_column_name: If provided, the arrow RecordBatch produced by
      the TFXIO will contain a string column of the given name, and the contents
      of that column will be the raw records. Note that not all TFXIO supports
      this option, and an error will be raised in that case. Required if
      read_as_raw_records == True.

  Returns:
    A function that takes a file pattern as input and returns a TFXIO
    instance.

  Raises:
    NotImplementedError: when given an unsupported example payload type.
  """
    assert examples.type is standard_artifacts.Examples, (
        'examples must be of type standard_artifacts.Examples')
    # If the payload format custom property is not set, assume tf.Example.
    payload_format = examples_utils.get_payload_format(examples)
    data_view_uri = None
    if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PROTO:
        data_view_uri = examples.get_string_custom_property(
            constants.DATA_VIEW_URI_PROPERTY_KEY)
        if not data_view_uri:
            data_view_uri = None
    return lambda file_pattern: make_tfxio(  # pylint:disable=g-long-lambda
        file_pattern=file_pattern,
        telemetry_descriptors=telemetry_descriptors,
        payload_format=payload_format,
        data_view_uri=data_view_uri,
        schema=schema,
        read_as_raw_records=read_as_raw_records,
        raw_record_column_name=raw_record_column_name)
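A hedged usage sketch for the returned factory; the URI, descriptor, and file pattern are illustrative, and standard_artifacts is assumed imported:

# Hypothetical usage sketch.
examples = standard_artifacts.Examples()
examples.uri = '/tmp/examples/train'
factory = get_tfxio_factory_from_artifact(
    examples, telemetry_descriptors=['MyComponent'])
my_tfxio = factory('/tmp/examples/train/*.gz')
# my_tfxio.BeamSource() can then be used in a Beam pipeline to read records.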
Example #10
  def search_artifacts(self, artifact_name: Text,
                       pipeline_info: data_types.PipelineInfo,
                       producer_component_id: Text) -> List[Artifact]:
    """Search artifacts that matches given info.

    Args:
      artifact_name: the name of the artifact that set by producer component.
        The name is logged both in artifacts and the events when the execution
        being published.
      pipeline_info: the information of the current pipeline
      producer_component_id: the id of the component that produces the artifact

    Returns:
      A list of Artifacts that matches the given info

    Raises:
      RuntimeError: when no matching execution is found given producer info.
    """
    producer_execution = None
    matching_artifact_ids = set()
    # TODO(ruoyu): We need to revisit this when adding support for async
    # execution.
    context = self.get_pipeline_run_context(pipeline_info)
    if context is None:
      raise RuntimeError('Pipeline run context for %s does not exist' %
                         pipeline_info)
    for execution in self.store.get_executions_by_context(context.id):
      if execution.properties[
          'component_id'].string_value == producer_component_id:
        producer_execution = execution
        break
    if not producer_execution:
      raise RuntimeError(
          'Cannot find matching execution with pipeline name %s, '
          'run id %s and component id %s' %
          (pipeline_info.pipeline_name, pipeline_info.run_id,
           producer_component_id))
    for event in self.store.get_events_by_execution_ids(
        [producer_execution.id]):
      if (event.type == metadata_store_pb2.Event.OUTPUT and
          event.path.steps[0].key == artifact_name):
        matching_artifact_ids.add(event.artifact_id)

    # Get relevant artifacts along with their types.
    artifacts_by_id = self.store.get_artifacts_by_id(
        list(matching_artifact_ids))
    matching_artifact_type_ids = list(set(a.type_id for a in artifacts_by_id))
    matching_artifact_types = self.store.get_artifact_types_by_id(
        matching_artifact_type_ids)
    artifact_types = dict(
        zip(matching_artifact_type_ids, matching_artifact_types))

    result_artifacts = []
    for a in artifacts_by_id:
      tfx_artifact = Artifact(mlmd_artifact_type=artifact_types[a.type_id])
      tfx_artifact.set_mlmd_artifact(a)
      result_artifacts.append(tfx_artifact)
    return result_artifacts
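A hedged usage sketch; `metadata_handler` is assumed to be a Metadata handle exposing this method, and the run must already exist in MLMD:

# Hypothetical usage sketch; names are illustrative.
pipeline_info = data_types.PipelineInfo(
    pipeline_name='my_pipeline',
    pipeline_root='/tmp/pipeline_root',
    run_id='2021-01-01T00:00:00')
artifacts = metadata_handler.search_artifacts(
    artifact_name='examples',
    pipeline_info=pipeline_info,
    producer_component_id='CsvExampleGen')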
Example #11
def refactor_model_blessing(model_blessing: artifact.Artifact,
                            name_from_id: Mapping[int, str]) -> None:
    """Changes id-typed custom properties to string-typed runtime artifact names."""
    # Both the baseline and the current model are referenced by id; rewrite
    # each one that is present to its resolved runtime artifact name.
    for key in (constants.ARTIFACT_PROPERTY_BASELINE_MODEL_ID_KEY,
                constants.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY):
        if model_blessing.has_custom_property(key):
            model_blessing.set_string_custom_property(
                key,
                _get_full_name(
                    artifact_id=model_blessing.get_int_custom_property(key),
                    name_from_id=name_from_id))
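A hedged usage sketch; the id-to-name mapping and the id value are illustrative, and standard_artifacts is assumed imported:

# Hypothetical usage sketch.
blessing = standard_artifacts.ModelBlessing()
blessing.set_int_custom_property(
    constants.ARTIFACT_PROPERTY_CURRENT_MODEL_ID_KEY, 7)
refactor_model_blessing(
    blessing, name_from_id={7: 'my_pipeline:run-1:Trainer:model:0'})
# The id-typed property is now replaced by a string-typed runtime artifact name.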
Example #12
def parse_artifact_dict(json_str: Text) -> Dict[Text, List[Artifact]]:
    """Parses a dict from key to list of Artifact from its json format."""
    tfx_artifacts = {}
    for key, artifact_list in json.loads(json_str).items():
        tfx_artifacts[key] = [
            Artifact.from_json_dict(v) for v in artifact_list
        ]
    return tfx_artifacts
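A hedged round-trip sketch, assuming a jsonify_artifact_dict counterpart exists in the same module and standard_artifacts is imported:

# Hypothetical round-trip sketch.
examples = standard_artifacts.Examples()
serialized = jsonify_artifact_dict({'examples': [examples]})
restored = parse_artifact_dict(serialized)
assert list(restored.keys()) == ['examples']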
Example #13
def testArtifactCollectionAsChannel(self):
    instance_a = Artifact('MyTypeName')
    instance_b = Artifact('MyTypeName')
    chnl = channel_utils.as_channel([instance_a, instance_b])
    self.assertEqual(chnl.type_name, 'MyTypeName')
    self.assertItemsEqual(chnl.get(), [instance_a, instance_b])
Example #14
def test_invalid_channel_type(self):
    instance_a = Artifact('MyTypeName')
    instance_b = Artifact('MyTypeName')
    with self.assertRaises(ValueError):
        Channel('AnotherTypeName', artifacts=[instance_a, instance_b])
Example #15
def test_valid_channel(self):
    instance_a = Artifact('MyTypeName')
    instance_b = Artifact('MyTypeName')
    chnl = Channel('MyTypeName', artifacts=[instance_a, instance_b])
    self.assertEqual(chnl.type_name, 'MyTypeName')
    self.assertItemsEqual(chnl.get(), [instance_a, instance_b])