コード例 #1
0
ファイル: examples_utils_test.py プロジェクト: jay90099/tfx
    def test_get_payload_format(self):
        """Payload format defaults to tf.Example and follows the custom property."""
        examples = standard_artifacts.Examples()

        def check(expected_format, expected_name):
            # Both the enum accessor and the string accessor must agree.
            self.assertEqual(examples_utils.get_payload_format(examples),
                             expected_format)
            self.assertEqual(examples_utils.get_payload_format_string(examples),
                             expected_name)

        # No payload-format custom property set yet.
        check(example_gen_pb2.PayloadFormat.FORMAT_TF_EXAMPLE,
              'FORMAT_TF_EXAMPLE')

        # Tag the artifact explicitly as FORMAT_PROTO.
        examples.set_string_custom_property(utils.PAYLOAD_FORMAT_PROPERTY_NAME,
                                            'FORMAT_PROTO')
        check(example_gen_pb2.PayloadFormat.FORMAT_PROTO, 'FORMAT_PROTO')
コード例 #2
0
def _get_payload_format(examples: List[artifact.Artifact]) -> int:
  """Returns the single payload format shared by all `examples`.

  Args:
    examples: Examples artifacts whose payload formats are compared.

  Returns:
    The common payload format, as returned by
    `examples_utils.get_payload_format`.

  Raises:
    ValueError: If `examples` is empty, or if the artifacts do not all share
      the same payload format.
  """
  payload_formats = {examples_utils.get_payload_format(e) for e in examples}
  # Report empty input explicitly instead of the misleading
  # "different payload formats: set()" message below.
  if not payload_formats:
    raise ValueError('At least one examples artifact is needed.')
  if len(payload_formats) != 1:
    raise ValueError('Unable to read example artifacts of different payload '
                     'formats: {}'.format(payload_formats))
  return payload_formats.pop()
コード例 #3
0
ファイル: request_builder.py プロジェクト: vikrosj/tfx
    def ReadExamplesArtifact(self,
                             examples: types.Artifact,
                             num_examples: int,
                             split_name: Optional[Text] = None):
        """Read records from Examples artifact.

        Records are read as raw records (`read_as_raw_records=True`), and the
        artifact's payload format is stored in `self._payload_format`.

        Args:
          examples: `Examples` artifact.
          num_examples: Number of examples to read. If the specified value is
            larger than the actual number of examples, all examples would be
            read.
          split_name: Name of the split to read from the Examples artifact.
            If None, the first split listed on the artifact is used.

        Raises:
          RuntimeError: If read twice.
          ValueError: If `num_examples` < 1, the artifact lists no splits,
            `split_name` is not one of the artifact's splits, or no files
            match `<examples.uri>/<split_name>/*`.
        """
        if self._records:
            raise RuntimeError('Cannot read records twice.')

        if num_examples < 1:
            raise ValueError('num_examples < 1 (got {})'.format(num_examples))

        available_splits = artifact_utils.decode_split_names(
            examples.split_names)
        if not available_splits:
            raise ValueError(
                'No split_name is available in given Examples artifact.')
        if split_name is None:
            split_name = available_splits[0]
        if split_name not in available_splits:
            raise ValueError(
                'No split_name {}; available split names: {}'.format(
                    split_name, ', '.join(available_splits)))

        # ExampleGen generates artifacts under each split_name directory.
        glob_pattern = os.path.join(examples.uri, split_name, '*')
        try:
            filenames = fileio.glob(glob_pattern)
        except tf.errors.NotFoundError:
            # A missing split directory is reported like an empty one below.
            filenames = []
        if not filenames:
            raise ValueError(
                'Unable to find examples matching {}.'.format(glob_pattern))

        # Build the TFXIO factory only once we know there are files to read.
        tfxio_factory = tfxio_utils.get_tfxio_factory_from_artifact(
            examples=[examples],
            telemetry_descriptors=_TELEMETRY_DESCRIPTORS,
            schema=None,
            read_as_raw_records=True,
            raw_record_column_name=_RAW_RECORDS_COLUMN)

        self._payload_format = examples_utils.get_payload_format(examples)
        tfxio = tfxio_factory(filenames)

        # batch_size=num_examples; presumably _ReadFromDataset consumes only
        # the first batch -- confirm against its implementation.
        self._ReadFromDataset(
            tfxio.TensorFlowDataset(
                dataset_options.TensorFlowDatasetOptions(
                    batch_size=num_examples)))
コード例 #4
0
ファイル: tfxio_utils.py プロジェクト: Mistobaan/tfx
def get_tfxio_factory_from_artifact(
    examples: artifact.Artifact,
    telemetry_descriptors: List[Text],
    schema: Optional[schema_pb2.Schema] = None,
    read_as_raw_records: bool = False,
    raw_record_column_name: Optional[Text] = None
) -> Callable[[Text], tfxio.TFXIO]:
    """Returns a factory function that creates a proper TFXIO.

    Args:
      examples: The Examples artifact that the TFXIO is intended to access.
        A non-empty list (or tuple) of Examples artifacts is also accepted; in
        that case all of them must share the same payload format and, for
        FORMAT_PROTO, the same data view URI.
      telemetry_descriptors: A set of descriptors that identify the component
        that is instantiating the TFXIO. These will be used to construct the
        namespace to contain metrics for profiling and are therefore expected
        to be identifiers of the component itself and not individual instances
        of source use.
      schema: TFMD schema. Note that without a schema, some TFXIO interfaces
        in certain TFXIO implementations might not be available.
      read_as_raw_records: If True, ignore the payload type of `examples`.
        Always use RawTfRecord TFXIO.
      raw_record_column_name: If provided, the arrow RecordBatch produced by
        the TFXIO will contain a string column of the given name, and the
        contents of that column will be the raw records. Note that not all
        TFXIO supports this option, and an error will be raised in that case.
        Required if read_as_raw_records == True.

    Returns:
      A function that takes a file pattern as input and returns a TFXIO
      instance.

    Raises:
      AssertionError: If any given artifact is not a
        `standard_artifacts.Examples`.
      ValueError: If an empty list is given, or the given artifacts disagree
        on payload format or data view URI.
      NotImplementedError: when given an unsupported example payload type.
    """
    # Accept a single artifact (original interface) or a list of them, which
    # is how some callers invoke this function.
    if isinstance(examples, (list, tuple)):
        artifacts = list(examples)
    else:
        artifacts = [examples]
    if not artifacts:
        raise ValueError('At least one examples artifact is needed.')

    for examples_artifact in artifacts:
        assert examples_artifact.type is standard_artifacts.Examples, (
            'examples must be of type standard_artifacts.Examples')

    # In case that the payload format custom property is not set,
    # get_payload_format assumes tf.Example.
    payload_formats = {
        examples_utils.get_payload_format(a) for a in artifacts}
    if len(payload_formats) != 1:
        raise ValueError('Unable to read example artifacts of different payload '
                         'formats: {}'.format(payload_formats))
    payload_format = payload_formats.pop()

    data_view_uri = None
    if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PROTO:
        # An unset/empty URI is normalized to None.
        data_view_uris = {
            a.get_string_custom_property(
                constants.DATA_VIEW_URI_PROPERTY_KEY) or None
            for a in artifacts
        }
        if len(data_view_uris) != 1:
            raise ValueError('Unable to read example artifacts with different '
                             'data views: {}'.format(data_view_uris))
        data_view_uri = data_view_uris.pop()

    def _make_tfxio_for_pattern(file_pattern):
        return make_tfxio(
            file_pattern=file_pattern,
            telemetry_descriptors=telemetry_descriptors,
            payload_format=payload_format,
            data_view_uri=data_view_uri,
            schema=schema,
            read_as_raw_records=read_as_raw_records,
            raw_record_column_name=raw_record_column_name)

    return _make_tfxio_for_pattern
コード例 #5
0
ファイル: tfxio_utils.py プロジェクト: htahir1/tfx
def _get_data_view_info(
        examples: artifact.Artifact) -> Optional[Tuple[str, int]]:
    """Returns the data view URI and creation time attached to `examples`.

    A data view is only looked up when the artifact's payload format is
    FORMAT_PROTO.

    Args:
      examples: An Examples artifact.

    Returns:
      `(data_view_uri, data_view_create_time)` read from the artifact's custom
      properties, or None if the payload format is not FORMAT_PROTO or no data
      view URI is set.

    Raises:
      AssertionError: If `examples` is not a `standard_artifacts.Examples`.
    """
    assert examples.type is standard_artifacts.Examples, (
        'examples must be of type standard_artifacts.Examples')
    payload_format = examples_utils.get_payload_format(examples)
    if payload_format != example_gen_pb2.PayloadFormat.FORMAT_PROTO:
        return None

    data_view_uri = examples.get_string_custom_property(
        constants.DATA_VIEW_URI_PROPERTY_KEY)
    if not data_view_uri:
        return None

    data_view_create_time = examples.get_int_custom_property(
        constants.DATA_VIEW_CREATE_TIME_KEY)
    return data_view_uri, data_view_create_time
コード例 #6
0
ファイル: examples_utils_test.py プロジェクト: jay90099/tfx
 def test_get_payload_format_invalid_artifact_type(self):
     """get_payload_format must reject an artifact that is not Examples."""
     schema_artifact = standard_artifacts.Schema()
     self.assertRaises(AssertionError, examples_utils.get_payload_format,
                       schema_artifact)