Ejemplo n.º 1
0
    def testReadTfRecord(self):
        """Checks that ReadTfRecord reads records from several file patterns.

        Two TFRecord files holding two records each are written into a fresh
        temporary directory and read back through a list of two glob patterns.
        The pipeline output must contain exactly the written records; their
        order is not checked.
        """
        tmp_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)

        def _WriteTfRecord(path, records):
            """Writes every element of `records` to a TFRecord file at `path`."""
            with tf.io.TFRecordWriter(path) as w:
                for r in records:
                    w.write(r)

        file1 = os.path.join(tmp_dir, "tfrecord1")
        file1_records = [b"aa", b"bb"]
        _WriteTfRecord(file1, file1_records)
        file2 = os.path.join(tmp_dir, "tfrecord2")
        file2_records = [b"cc", b"dd"]
        _WriteTfRecord(file2, file2_records)

        def _CheckRecords(actual, expected):
            """Asserts `actual` holds the same records as `expected`, any order."""
            # assertCountEqual compares multiplicities as well as membership,
            # so a record that is read more than once is reported instead of
            # being collapsed away as it would be by a set comparison.
            self.assertCountEqual(actual, expected)

        # Test reading multiple file patterns.
        with beam.Pipeline() as p:
            record_pcoll = p | record_based_tfxio.ReadTfRecord(
                [file1 + "*", file2 + "*"])
            beam_test_util.assert_that(
                record_pcoll, lambda actual: _CheckRecords(
                    actual, file1_records + file2_records))
Ejemplo n.º 2
0
 def _RawRecordBeamSourceInternal(self) -> beam.PTransform:
     """Builds the raw-record source for this instance's file pattern.

     Returns:
       The result of `record_based_tfxio.ReadTfRecord` called with
       `self._file_pattern`.
     """
     file_pattern = self._file_pattern
     return record_based_tfxio.ReadTfRecord(file_pattern)