Example #1
0
    def testRecordBatchAndTensorAdapter(self):
        """Checks ArrowSchema, BeamSource output, TensorAdapter and metrics."""
        column_name = "raw_record"
        telemetry_descriptors = ["some", "component"]
        tfxio = raw_tf_record.RawTfRecordTFXIO(
            self._raw_record_file,
            column_name,
            telemetry_descriptors=telemetry_descriptors)
        # The element type depends on whether this TFXIO emits large Arrow types.
        if _ProducesLargeTypes(tfxio):
            expected_type = pa.large_list(pa.large_binary())
        else:
            expected_type = pa.list_(pa.binary())

        actual_schema = tfxio.ArrowSchema()
        expected_schema = pa.schema([pa.field(column_name, expected_type)])
        self.assertTrue(actual_schema.equals(expected_schema),
                        "got: {}".format(actual_schema))

        def _AssertFn(record_batches):
            self.assertLen(record_batches, 1)
            batch = record_batches[0]
            self.assertTrue(batch.schema.equals(tfxio.ArrowSchema()))
            expected_column = pa.array([[r] for r in _RAW_RECORDS],
                                       type=expected_type)
            self.assertTrue(batch.columns[0].equals(expected_column))
            tensors = tfxio.TensorAdapter().ToBatchTensors(batch)
            self.assertLen(tensors, 1)
            self.assertIn(column_name, tensors)

        # Run the pipeline explicitly (not via `with`) so the result object is
        # available for telemetry validation afterwards.
        p = beam.Pipeline()
        record_batch_pcoll = p | tfxio.BeamSource(batch_size=len(_RAW_RECORDS))
        beam_testing_util.assert_that(record_batch_pcoll, _AssertFn)
        pipeline_result = p.run()
        pipeline_result.wait_until_finish()
        telemetry_test_util.ValidateMetrics(self, pipeline_result,
                                            telemetry_descriptors, "bytes",
                                            "tfrecords_gzip")
Example #2
0
 def testTensorFlowDatasetGraphMode(self):
     """Reads all records via TensorFlowDataset under a TF1 graph/session."""
     column_name = "raw_record"
     tfxio = raw_tf_record.RawTfRecordTFXIO(
         self._raw_record_file,
         column_name,
         telemetry_descriptors=["some", "component"])
     actual_records = []
     with tf.compat.v1.Graph().as_default():
         options = dataset_options.TensorFlowDatasetOptions(
             batch_size=1,
             shuffle=False,
             num_epochs=1,
             reader_num_threads=1,
             sloppy_ordering=False)
         dataset = tfxio.TensorFlowDataset(options)
         next_elem = tf.compat.v1.data.make_one_shot_iterator(
             dataset).get_next()
         with tf.compat.v1.Session() as sess:
             # Drain the one-shot iterator until it signals exhaustion.
             while True:
                 try:
                     batch = sess.run(next_elem)
                 except tf.errors.OutOfRangeError:
                     break
                 actual_records.append(batch[column_name][0])
     self.assertEqual(actual_records, _RAW_RECORDS)
    def testProject(self):
        """Projecting onto the sole column is a no-op; unknown names raise."""
        column_name = "raw_record"
        tfxio = raw_tf_record.RawTfRecordTFXIO(self._raw_record_file,
                                               column_name)
        projected = tfxio.Project([column_name])
        # Projection onto the only existing column must preserve both the
        # Arrow schema and the tensor representations.
        self.assertTrue(projected.ArrowSchema().equals(tfxio.ArrowSchema()))
        self.assertEqual(projected.TensorRepresentations(),
                         tfxio.TensorRepresentations())

        with self.assertRaises(AssertionError):
            tfxio.Project(["some_other_name"])
Example #4
0
 def testTensorFlowDataset(self):
     """Reads all records via TensorFlowDataset in eager mode."""
     column_name = "raw_record"
     tfxio = raw_tf_record.RawTfRecordTFXIO(
         self._raw_record_file,
         column_name,
         telemetry_descriptors=["some", "component"])
     options = dataset_options.TensorFlowDatasetOptions(
         batch_size=1,
         shuffle=False,
         num_epochs=1,
         reader_num_threads=1,
         sloppy_ordering=False)
     actual_records = []
     # batch_size=1, so each element holds exactly one record at index 0.
     for batch in tfxio.TensorFlowDataset(options):
         actual_records.append(batch[column_name].numpy()[0])
     self.assertEqual(actual_records, _RAW_RECORDS)
Example #5
0
    def testRecordBatchAndTensorAdapter(self):
        """Checks the Arrow schema, BeamSource output and TensorAdapter."""
        column_name = "raw_record"
        tfxio = raw_tf_record.RawTfRecordTFXIO(self._raw_record_file,
                                               column_name)
        # BUG FIX: the original called assertTrue(schema, expected_schema),
        # which treats the second argument as the failure *message* and only
        # checks that the schema object is truthy -- the expected schema was
        # never compared. Compare the schemas for real, as _AssertFn below
        # already does for the record batch schema.
        expected_schema = pa.schema(
            [pa.field(column_name, pa.list_(pa.binary()))])
        got_schema = tfxio.ArrowSchema()
        self.assertTrue(got_schema.equals(expected_schema),
                        "got: {}".format(got_schema))

        def _AssertFn(record_batches):
            self.assertLen(record_batches, 1)
            record_batch = record_batches[0]
            self.assertTrue(record_batch.schema.equals(tfxio.ArrowSchema()))
            self.assertTrue(record_batch.columns[0].equals(
                pa.array([[r] for r in _RAW_RECORDS],
                         type=pa.list_(pa.binary()))))
            tensor_adapter = tfxio.TensorAdapter()
            tensors = tensor_adapter.ToBatchTensors(record_batch)
            self.assertLen(tensors, 1)
            self.assertIn(column_name, tensors)

        with beam.Pipeline() as p:
            record_batch_pcoll = p | tfxio.BeamSource(
                batch_size=len(_RAW_RECORDS))
            beam_testing_util.assert_that(record_batch_pcoll, _AssertFn)
Example #6
0
def make_tfxio(
        file_pattern: OneOrMorePatterns,
        telemetry_descriptors: List[str],
        payload_format: Union[str, int],
        data_view_uri: Optional[str] = None,
        schema: Optional[schema_pb2.Schema] = None,
        read_as_raw_records: bool = False,
        raw_record_column_name: Optional[str] = None,
        file_format: Optional[Union[str, List[str]]] = None) -> tfxio.TFXIO:
    """Creates a TFXIO instance that reads `file_pattern`.

    Args:
      file_pattern: the file pattern for the TFXIO to access.
      telemetry_descriptors: A set of descriptors that identify the component
        that is instantiating the TFXIO. These will be used to construct the
        namespace to contain metrics for profiling and are therefore expected
        to be identifiers of the component itself and not individual instances
        of source use.
      payload_format: one of the enums from example_gen_pb2.PayloadFormat (may
        be in string or int form). If None, default to FORMAT_TF_EXAMPLE.
      data_view_uri: uri to a DataView artifact. A DataView is needed in order
        to create a TFXIO for certain payload formats.
      schema: TFMD schema. Note: although optional, some payload formats need a
        schema in order for all TFXIO interfaces (e.g. TensorAdapter()) to
        work. Unless you know what you are doing, always supply a schema.
      read_as_raw_records: If True, ignore the payload type of `examples`.
        Always use RawTfRecord TFXIO.
      raw_record_column_name: If provided, the arrow RecordBatch produced by
        the TFXIO will contain a string column of the given name, and the
        contents of that column will be the raw records. Note that not all
        TFXIO supports this option, and an error will be raised in that case.
        Required if read_as_raw_records == True.
      file_format: file format string for each file_pattern. Only
        'tfrecords_gzip' is supported for now.

    Returns:
      a TFXIO instance.

    Raises:
      ValueError: if file_format and file_pattern disagree in type or length.
      NotImplementedError: if file_format or payload_format is unsupported.
    """
    if not isinstance(payload_format, int):
        payload_format = example_gen_pb2.PayloadFormat.Value(payload_format)

    if file_format is not None:
        # file_format must mirror file_pattern: both a single string, or
        # parallel lists of equal length.
        if type(file_format) is not type(file_pattern):
            raise ValueError(
                f'The type of file_pattern and file_format should be the '
                f'same. Given: file_pattern={file_pattern}, '
                f'file_format={file_format}')
        if isinstance(file_format, list):
            if len(file_format) != len(file_pattern):
                raise ValueError(
                    f'The length of file_pattern and file_format should be '
                    f'the same. Given: file_pattern={file_pattern}, '
                    f'file_format={file_format}')
            unsupported = any(
                item != 'tfrecords_gzip' for item in file_format)
        else:  # file_format is str type.
            unsupported = file_format != 'tfrecords_gzip'
        if unsupported:
            raise NotImplementedError(f'{file_format} is not supported yet.')

    if read_as_raw_records:
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is so
        # callers that catch AssertionError keep working.
        assert raw_record_column_name is not None, (
            'read_as_raw_records is specified - '
            'must provide raw_record_column_name')
        return raw_tf_record.RawTfRecordTFXIO(
            file_pattern=file_pattern,
            raw_record_column_name=raw_record_column_name,
            telemetry_descriptors=telemetry_descriptors)

    if payload_format == example_gen_pb2.PayloadFormat.FORMAT_TF_EXAMPLE:
        return tf_example_record.TFExampleRecord(
            file_pattern=file_pattern,
            schema=schema,
            raw_record_column_name=raw_record_column_name,
            telemetry_descriptors=telemetry_descriptors)

    if (payload_format ==
            example_gen_pb2.PayloadFormat.FORMAT_TF_SEQUENCE_EXAMPLE):
        return tf_sequence_example_record.TFSequenceExampleRecord(
            file_pattern=file_pattern,
            schema=schema,
            raw_record_column_name=raw_record_column_name,
            telemetry_descriptors=telemetry_descriptors)

    if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PROTO:
        assert data_view_uri is not None, (
            'Accessing FORMAT_PROTO requires a DataView to parse the proto.')
        return record_to_tensor_tfxio.TFRecordToTensorTFXIO(
            file_pattern=file_pattern,
            saved_decoder_path=data_view_uri,
            telemetry_descriptors=telemetry_descriptors,
            raw_record_column_name=raw_record_column_name)

    # Typo fixed: was 'Unsupport payload format'.
    raise NotImplementedError(
        f'Unsupported payload format: {payload_format}')