def test_tensorflow_dataset_with_invalid_label_key(self):
  """An unknown `label_key` ("invalid") makes TensorFlowDataset raise ValueError."""
  record_tfxio = record_to_tensor_tfxio.TFRecordToTensorTFXIO(
      self._input_path, self._decoder_path, ["some", "component"])
  bad_label = "invalid"
  dataset_opts = dataset_options.TensorFlowDatasetOptions(
      batch_size=1, shuffle=False, num_epochs=1, label_key=bad_label)
  # The error is expected when the dataset is requested, not when options
  # are built.
  with self.assertRaisesRegex(ValueError, "The `label_key` provided.*"):
    record_tfxio.TensorFlowDataset(options=dataset_opts)
 def test_tensorflow_dataset(self):
   """Every tensor of every dataset element matches _RECORDS_AS_TENSORS."""
   record_tfxio = record_to_tensor_tfxio.TFRecordToTensorTFXIO(
       self._input_path, self._decoder_path, ["some", "component"])
   dataset = record_tfxio.TensorFlowDataset(
       options=dataset_options.TensorFlowDatasetOptions(
           batch_size=1, shuffle=False, num_epochs=1))
   # batch_size=1 and shuffle=False: element `step` lines up with
   # _RECORDS_AS_TENSORS[step].
   for step, tensors_by_name in enumerate(dataset):
     for name, actual in tensors_by_name.items():
       self._AssertSparseTensorEqual(actual, _RECORDS_AS_TENSORS[step][name])
    def test_simple(self, attach_raw_records):
        """Exercises TFRecordToTensorTFXIO end to end.

        Checks the Arrow schema (optionally with a raw-record column), the
        tensor representations, the adapter's type specs, the record batch
        produced by BeamSource, and the telemetry metrics of the run.
        """
        if attach_raw_records:
            raw_record_column_name = "_raw_records"
        else:
            raw_record_column_name = None
        tfxio = record_to_tensor_tfxio.TFRecordToTensorTFXIO(
            self._input_path,
            self._decoder_path,
            _TELEMETRY_DESCRIPTORS,
            raw_record_column_name=raw_record_column_name)

        expected_fields = [
            pa.field("st1", pa.list_(pa.binary())),
            pa.field("st2", pa.list_(pa.binary())),
        ]
        if attach_raw_records:
            # The raw-record column only uses large Arrow types when this TFXIO
            # can produce them.
            if tfxio._can_produce_large_types:
                raw_column_type = pa.large_list(pa.large_binary())
            else:
                raw_column_type = pa.list_(pa.binary())
            expected_fields.append(
                pa.field(raw_record_column_name, raw_column_type))
        arrow_schema = tfxio.ArrowSchema()
        self.assertTrue(arrow_schema.equals(pa.schema(expected_fields)),
                        arrow_schema)

        expected_representations = {
            "st1": text_format.Parse(
                """varlen_sparse_tensor { column_name: "st1" }""",
                schema_pb2.TensorRepresentation()),
            "st2": text_format.Parse(
                """varlen_sparse_tensor { column_name: "st2" }""",
                schema_pb2.TensorRepresentation()),
        }
        self.assertEqual(tfxio.TensorRepresentations(),
                         expected_representations)

        adapter = tfxio.TensorAdapter()
        self.assertEqual(adapter.TypeSpecs(),
                         _DecoderForTesting().output_type_specs())

        def _check_record_batches(record_batches):
            # All records fit into a single batch (batch_size=len(_RECORDS)).
            self.assertLen(record_batches, 1)
            batch = record_batches[0]
            self.assertTrue(batch.schema.equals(tfxio.ArrowSchema()))
            converted = adapter.ToBatchTensors(batch)
            self.assertLen(converted, 2)
            for name in ("st1", "st2"):
                self.assertIn(name, converted)
                sparse = converted[name]
                self.assertAllEqual(sparse.values, _RECORDS)
                self.assertAllEqual(sparse.indices, [[0, 0], [1, 0]])
                self.assertAllEqual(sparse.dense_shape, [2, 1])

        pipeline = beam.Pipeline()
        batches = pipeline | tfxio.BeamSource(batch_size=len(_RECORDS))
        beam_testing_util.assert_that(batches, _check_record_batches)
        result = pipeline.run()
        result.wait_until_finish()
        telemetry_test_util.ValidateMetrics(
            self, result, _TELEMETRY_DESCRIPTORS, "tensor", "tfrecords_gzip")
Beispiel #4
0
 def test_get_decode_function(self):
     """DecodeFunction() turns a batch of raw records into st1/st2 sparse tensors."""
     saved_decoder = _write_decoder()
     record_tfxio = record_to_tensor_tfxio.TFRecordToTensorTFXIO(
         self._input_path, saved_decoder, ["some", "component"])
     decode = record_tfxio.DecodeFunction()
     outputs = decode(tf.constant(_RECORDS))
     for name in ("st1", "st2"):
         self.assertIn(name, outputs)
         sparse = outputs[name]
         self.assertAllEqual(sparse.values, _RECORDS)
         self.assertAllEqual(sparse.indices, [[0, 0], [1, 0]])
         self.assertAllEqual(sparse.dense_shape, [2, 1])
Beispiel #5
0
 def test_tensorflow_dataset_with_label_key(self):
   """With `label_key` set, each dataset element is a (features, label) pair."""
   label_key = "st1"
   record_tfxio = record_to_tensor_tfxio.TFRecordToTensorTFXIO(
       self._input_path, _write_decoder(), ["some", "component"])
   dataset = record_tfxio.TensorFlowDataset(
       options=dataset_options.TensorFlowDatasetOptions(
           batch_size=1, shuffle=False, num_epochs=1, label_key=label_key))
   for step, (features, label) in enumerate(dataset):
     expected = _RECORDS_AS_TENSORS[step]
     self._assert_sparse_tensor_equal(label, expected[label_key])
     for name, actual in features.items():
       self._assert_sparse_tensor_equal(actual, expected[name])
 def test_projected_tensorflow_dataset(self):
   """A TFXIO projected onto "st1" yields dataset elements holding only "st1"."""
   feature_name = "st1"
   projected_tfxio = record_to_tensor_tfxio.TFRecordToTensorTFXIO(
       self._input_path, self._decoder_path,
       ["some", "component"]).Project([feature_name])
   dataset = projected_tfxio.TensorFlowDataset(
       options=dataset_options.TensorFlowDatasetOptions(
           batch_size=1, shuffle=False, num_epochs=1))
   for step, tensors_by_name in enumerate(dataset):
     self.assertIn(feature_name, tensors_by_name)
     self.assertLen(tensors_by_name, 1)
     self._AssertSparseTensorEqual(tensors_by_name[feature_name],
                                   _RECORDS_AS_TENSORS[step][feature_name])
  def test_project(self):
    """Projecting onto "st1" drops "st2" from representations and tensors.

    The pipeline reads through the projected TFXIO's own BeamSource, so the
    projected TFXIO is tested as a whole (BeamSource + TensorAdapter).
    """
    tfxio = record_to_tensor_tfxio.TFRecordToTensorTFXIO(
        self._input_path, self._decoder_path, ["some", "component"])
    projected = tfxio.Project(["st1"])
    self.assertIn("st1", projected.TensorRepresentations())
    self.assertNotIn("st2", projected.TensorRepresentations())
    tensor_adapter = projected.TensorAdapter()

    def _assert_fn(list_of_rb):
      # All records fit into one batch (batch_size=len(_RECORDS)).
      self.assertLen(list_of_rb, 1)
      rb = list_of_rb[0]
      tensors = tensor_adapter.ToBatchTensors(rb)
      # Only the projected tensor is produced by the projected adapter.
      self.assertLen(tensors, 1)
      self.assertIn("st1", tensors)
      st = tensors["st1"]
      self.assertAllEqual(st.values, _RECORDS)
      self.assertAllEqual(st.indices, [[0, 0], [1, 0]])
      self.assertAllEqual(st.dense_shape, [2, 1])

    with beam.Pipeline() as p:
      # Previously this read through the unprojected `tfxio.BeamSource`, which
      # left the projected TFXIO's BeamSource untested.
      rb_pcoll = p | projected.BeamSource(batch_size=len(_RECORDS))
      beam_testing_util.assert_that(rb_pcoll, _assert_fn)
Beispiel #8
0
def make_tfxio(
        file_pattern: OneOrMorePatterns,
        telemetry_descriptors: List[str],
        payload_format: Optional[Union[str, int]],
        data_view_uri: Optional[str] = None,
        schema: Optional[schema_pb2.Schema] = None,
        read_as_raw_records: bool = False,
        raw_record_column_name: Optional[str] = None,
        file_format: Optional[Union[str, List[str]]] = None) -> tfxio.TFXIO:
    """Creates a TFXIO instance that reads `file_pattern`.

    Args:
      file_pattern: the file pattern (or list of patterns) for the TFXIO to
        access.
      telemetry_descriptors: A set of descriptors that identify the component
        that is instantiating the TFXIO. These will be used to construct the
        namespace to contain metrics for profiling and are therefore expected
        to be identifiers of the component itself and not individual instances
        of source use.
      payload_format: one of the enums from example_gen_pb2.PayloadFormat (in
        string or int form). If None, defaults to FORMAT_TF_EXAMPLE.
      data_view_uri: uri to a DataView artifact. A DataView is needed in order
        to create a TFXIO for certain payload formats (FORMAT_PROTO requires
        it).
      schema: TFMD schema. Note: although optional, some payload formats need a
        schema in order for all TFXIO interfaces (e.g. TensorAdapter()) to
        work. Unless you know what you are doing, always supply a schema.
      read_as_raw_records: If True, ignore `payload_format` and always use a
        RawTfRecord TFXIO.
      raw_record_column_name: If provided, the arrow RecordBatch produced by
        the TFXIO will contain a string column of the given name, and the
        contents of that column will be the raw records. Not all TFXIOs
        support this option, and an error will be raised in that case.
        Required if read_as_raw_records == True.
      file_format: file format string for each file_pattern. Must have the
        same type as `file_pattern`, and the same length when both are lists.
        Only 'tfrecords_gzip' is supported for now.

    Returns:
      a TFXIO instance.

    Raises:
      ValueError: if `file_format` and `file_pattern` differ in type, or are
        lists of different lengths.
      NotImplementedError: if a `file_format` other than 'tfrecords_gzip' is
        requested, or `payload_format` is not a supported format.
      AssertionError: if `read_as_raw_records` is set without
        `raw_record_column_name`, or FORMAT_PROTO is requested without
        `data_view_uri`.
    """
    if payload_format is None:
        # Documented default; previously None reached PayloadFormat.Value(None)
        # and failed.
        payload_format = example_gen_pb2.PayloadFormat.FORMAT_TF_EXAMPLE
    elif not isinstance(payload_format, int):
        payload_format = example_gen_pb2.PayloadFormat.Value(payload_format)

    if file_format is not None:
        if type(file_format) is not type(file_pattern):
            raise ValueError(
                'The type of file_pattern and file_formats should be the same. '
                f'Given: file_pattern={file_pattern}, file_format={file_format}')
        if isinstance(file_format, list):
            if len(file_format) != len(file_pattern):
                raise ValueError(
                    'The length of file_pattern and file_formats should be the '
                    f'same. Given: file_pattern={file_pattern}, '
                    f'file_format={file_format}')
            unsupported = any(fmt != 'tfrecords_gzip' for fmt in file_format)
        else:  # file_format is a single (non-list) value, e.g. a str.
            unsupported = file_format != 'tfrecords_gzip'
        if unsupported:
            raise NotImplementedError(f'{file_format} is not supported yet.')

    if read_as_raw_records:
        # NOTE(review): `assert` is skipped under `python -O`; kept as an
        # assert so the raised exception type stays unchanged for callers.
        assert raw_record_column_name is not None, (
            'read_as_raw_records is specified - '
            'must provide raw_record_column_name')
        return raw_tf_record.RawTfRecordTFXIO(
            file_pattern=file_pattern,
            raw_record_column_name=raw_record_column_name,
            telemetry_descriptors=telemetry_descriptors)

    if payload_format == example_gen_pb2.PayloadFormat.FORMAT_TF_EXAMPLE:
        return tf_example_record.TFExampleRecord(
            file_pattern=file_pattern,
            schema=schema,
            raw_record_column_name=raw_record_column_name,
            telemetry_descriptors=telemetry_descriptors)

    if (payload_format ==
            example_gen_pb2.PayloadFormat.FORMAT_TF_SEQUENCE_EXAMPLE):
        return tf_sequence_example_record.TFSequenceExampleRecord(
            file_pattern=file_pattern,
            schema=schema,
            raw_record_column_name=raw_record_column_name,
            telemetry_descriptors=telemetry_descriptors)

    if payload_format == example_gen_pb2.PayloadFormat.FORMAT_PROTO:
        assert data_view_uri is not None, (
            'Accessing FORMAT_PROTO requires a DataView to parse the proto.')
        return record_to_tensor_tfxio.TFRecordToTensorTFXIO(
            file_pattern=file_pattern,
            saved_decoder_path=data_view_uri,
            telemetry_descriptors=telemetry_descriptors,
            raw_record_column_name=raw_record_column_name)

    raise NotImplementedError(
        'Unsupported payload format: {}'.format(payload_format))
Beispiel #9
0
  def test_beam_source_and_tensor_adapter(
      self, attach_raw_records, create_decoder, beam_record_tfxio=False):
    """Checks schema, representations, type specs and BeamSource output.

    Uses TFRecordToTensorTFXIO over self._input_path by default, or
    BeamRecordToTensorTFXIO fed from beam.Create(_RECORDS) when
    `beam_record_tfxio` is set.
    """
    decoder = create_decoder()
    if attach_raw_records:
      raw_record_column_name = "_raw_records"
    else:
      raw_record_column_name = None
    decoder_path = _write_decoder(decoder)
    if not beam_record_tfxio:
      tfxio = record_to_tensor_tfxio.TFRecordToTensorTFXIO(
          self._input_path,
          decoder_path,
          _TELEMETRY_DESCRIPTORS,
          raw_record_column_name=raw_record_column_name)
    else:
      tfxio = record_to_tensor_tfxio.BeamRecordToTensorTFXIO(
          saved_decoder_path=decoder_path,
          telemetry_descriptors=_TELEMETRY_DESCRIPTORS,
          physical_format="tfrecords_gzip",
          raw_record_column_name=raw_record_column_name)

    binary_list_type = pa.large_list(pa.large_binary())
    st_fields = [
        pa.field("st1", binary_list_type),
        pa.field("st2", binary_list_type),
    ]
    expected_tensor_representations = {
        "st1": text_format.Parse(
            """varlen_sparse_tensor { column_name: "st1" }""",
            schema_pb2.TensorRepresentation()),
        "st2": text_format.Parse(
            """varlen_sparse_tensor { column_name: "st2" }""",
            schema_pb2.TensorRepresentation()),
    }
    if not isinstance(decoder, _DecoderForTestingWithRecordIndex):
      expected_fields = st_fields
    else:
      # The record-index decoder emits two extra columns ahead of st1/st2.
      expected_fields = [
          pa.field("ragged_record_index", pa.large_list(pa.int64())),
          pa.field("sparse_record_index", pa.large_list(pa.int64())),
      ] + st_fields
      expected_tensor_representations["ragged_record_index"] = (
          text_format.Parse(
              """ragged_tensor {
                   feature_path: { step: "ragged_record_index" }
                   row_partition_dtype: INT64
                 }""", schema_pb2.TensorRepresentation()))
      expected_tensor_representations["sparse_record_index"] = (
          text_format.Parse(
              """varlen_sparse_tensor { column_name: "sparse_record_index" }""",
              schema_pb2.TensorRepresentation()))
    if attach_raw_records:
      # The raw-record column, when requested, comes last.
      expected_fields.append(
          pa.field(raw_record_column_name, binary_list_type))

    arrow_schema = tfxio.ArrowSchema()
    self.assertTrue(arrow_schema.equals(pa.schema(expected_fields)),
                    arrow_schema)
    self.assertEqual(tfxio.TensorRepresentations(),
                     expected_tensor_representations)

    adapter = tfxio.TensorAdapter()
    self.assertEqual(adapter.TypeSpecs(), decoder.output_type_specs())

    def _check_record_batches(record_batches):
      # All records fit into one batch (batch_size=len(_RECORDS)).
      self.assertLen(record_batches, 1)
      batch = record_batches[0]
      self.assertTrue(batch.schema.equals(tfxio.ArrowSchema()))
      if attach_raw_records:
        raw_column = batch.column(batch.num_columns - 1)
        self.assertEqual(raw_column.flatten().to_pylist(), _RECORDS)
      converted = adapter.ToBatchTensors(batch)
      for name in ("st1", "st2"):
        self.assertIn(name, converted)
        sparse = converted[name]
        self.assertAllEqual(sparse.values, _RECORDS)
        self.assertAllEqual(sparse.indices, [[0, 0], [1, 0]])
        self.assertAllEqual(sparse.dense_shape, [2, 1])

    pipeline = beam.Pipeline()
    # BeamRecordToTensorTFXIO consumes a PCollection of records, while
    # TFRecordToTensorTFXIO reads directly from the pipeline root.
    if beam_record_tfxio:
      source_input = pipeline | beam.Create(_RECORDS)
    else:
      source_input = pipeline
    batches = source_input | tfxio.BeamSource(batch_size=len(_RECORDS))
    beam_testing_util.assert_that(batches, _check_record_batches)
    result = pipeline.run()
    result.wait_until_finish()
    telemetry_test_util.ValidateMetrics(
        self, result, _TELEMETRY_DESCRIPTORS, "tensor", "tfrecords_gzip")